# -*- encoding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Unit tests for pyspark.sql; additional tests are implemented as doctests in
individual modules.
"""

import os
import sys
import subprocess
import pydoc
import shutil
import tempfile
import pickle
import functools
import time
import datetime
import array
import ctypes
import warnings

import py4j

from contextlib import contextmanager

try:
    import xmlrunner
except ImportError:
    xmlrunner = None

if sys.version_info[:2] <= (2, 6):
    try:
        import unittest2 as unittest
    except ImportError:
        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
        sys.exit(1)
else:
    import unittest

from pyspark.util import _exception_message

_pandas_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pandas_version
    require_minimum_pandas_version()
except ImportError as e:
    # If Pandas version requirement is not satisfied, skip related tests.
    _pandas_requirement_message = _exception_message(e)

_pyarrow_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pyarrow_version
    require_minimum_pyarrow_version()
except ImportError as e:
    # If Arrow version requirement is not satisfied, skip related tests.
    _pyarrow_requirement_message = _exception_message(e)

_have_pandas = _pandas_requirement_message is None
_have_pyarrow = _pyarrow_requirement_message is None
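# These two flags are used further down in this file to guard tests that require
# Pandas or PyArrow, for example (illustrative sketch, not an actual test here):
#
#     @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
#     class SomePandasDependentTests(ReusedSQLTestCase):
#         ...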

from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
from pyspark.sql.types import _merge_type
from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException


class UTCOffsetTimezone(datetime.tzinfo):
    """
    Specifies a timezone as a fixed offset from UTC (in hours).
    """

    def __init__(self, offset=0):
        self.ZERO = datetime.timedelta(hours=offset)

    def utcoffset(self, dt):
        return self.ZERO

    def dst(self, dt):
        return self.ZERO

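# Illustrative usage of UTCOffsetTimezone (a sketch, not exercised directly here):
# UTCOffsetTimezone(2) acts as a fixed UTC+02:00 tzinfo, e.g.
#     datetime.datetime(2015, 11, 1, 0, 30, tzinfo=UTCOffsetTimezone(2))
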
class ExamplePointUDT(UserDefinedType):
    """
    User-defined type (UDT) for ExamplePoint.
    """

    @classmethod
    def sqlType(self):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return 'pyspark.sql.tests'

    @classmethod
    def scalaUDT(cls):
        return 'org.apache.spark.sql.test.ExamplePointUDT'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return ExamplePoint(datum[0], datum[1])


class ExamplePoint:
    """
    An example class to demonstrate UDT in Scala, Java, and Python.
    """

    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        return isinstance(other, self.__class__) and \
            other.x == self.x and other.y == self.y


class PythonOnlyUDT(UserDefinedType):
    """
    User-defined type (UDT) for PythonOnlyPoint.
    """

    @classmethod
    def sqlType(self):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return '__main__'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return PythonOnlyPoint(datum[0], datum[1])

    @staticmethod
    def foo():
        pass

    @property
    def props(self):
        return {}


class PythonOnlyPoint(ExamplePoint):
    """
    An example class to demonstrate a UDT in Python only.
    """
    __UDT__ = PythonOnlyUDT()

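
# Illustrative sketch of how the UDT classes above are exercised (assumes an active
# SparkSession named `spark`):
#
#     df = spark.createDataFrame([(ExamplePoint(1.0, 2.0),)], ["point"])
#     df.first().point  # deserialized back into an ExamplePoint via ExamplePointUDT
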
class MyObject(object):
    def __init__(self, key, value):
        self.key = key
        self.value = value


class ReusedSQLTestCase(ReusedPySparkTestCase):
    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.spark = SparkSession(cls.sc)

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        cls.spark.stop()

    @contextmanager
    def sql_conf(self, pairs):
        """
        A convenient context manager to test configuration-specific logic. It sets
        each configuration `key` to the given `value` and restores the original
        values when it exits.
        """
        assert isinstance(pairs, dict), "pairs should be a dictionary."

        keys = pairs.keys()
        new_values = pairs.values()
        old_values = [self.spark.conf.get(key, None) for key in keys]
        for key, new_value in zip(keys, new_values):
            self.spark.conf.set(key, new_value)
        try:
            yield
        finally:
            for key, old_value in zip(keys, old_values):
                if old_value is None:
                    self.spark.conf.unset(key)
                else:
                    self.spark.conf.set(key, old_value)
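
    # Illustrative usage of `sql_conf` (the key below is just an example config):
    #
    #     with self.sql_conf({"spark.sql.shuffle.partitions": "4"}):
    #         ...  # assertions that depend on the overridden value
    #
    # On exit the previous value is restored, or the key is unset if it was absent.
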
    def assertPandasEqual(self, expected, result):
        msg = ("DataFrames are not equal: " +
               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
        self.assertTrue(expected.equals(result), msg=msg)


class DataTypeTests(unittest.TestCase):
    # regression test for SPARK-6055
    def test_data_type_eq(self):
        lt = LongType()
        lt2 = pickle.loads(pickle.dumps(LongType()))
        self.assertEqual(lt, lt2)

    # regression test for SPARK-7978
    def test_decimal_type(self):
        t1 = DecimalType()
        t2 = DecimalType(10, 2)
        self.assertTrue(t2 is not t1)
        self.assertNotEqual(t1, t2)
        t3 = DecimalType(8)
        self.assertNotEqual(t2, t3)

    # regression test for SPARK-10392
    def test_datetype_equal_zero(self):
        dt = DateType()
        self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))

    # regression test for SPARK-17035
    def test_timestamp_microsecond(self):
        tst = TimestampType()
        self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999)

    def test_empty_row(self):
        row = Row()
        self.assertEqual(len(row), 0)

    def test_struct_field_type_name(self):
        struct_field = StructField("a", IntegerType())
        self.assertRaises(TypeError, struct_field.typeName)


class SQLTests(ReusedSQLTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedSQLTestCase.setUpClass()
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(cls.tempdir.name)
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        cls.df = cls.spark.createDataFrame(cls.testData)

    @classmethod
    def tearDownClass(cls):
        ReusedSQLTestCase.tearDownClass()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)

    def test_sqlcontext_reuses_sparksession(self):
        sqlContext1 = SQLContext(self.sc)
        sqlContext2 = SQLContext(self.sc)
        self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)

    def tearDown(self):
        super(SQLTests, self).tearDown()

        # tear down test_bucketed_write state
        self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket")

    def test_row_should_be_read_only(self):
        row = Row(a=1, b=2)
        self.assertEqual(1, row.a)

        def foo():
            row.a = 3
        self.assertRaises(Exception, foo)

        row2 = self.spark.range(10).first()
        self.assertEqual(0, row2.id)

        def foo2():
            row2.id = 2
        self.assertRaises(Exception, foo2)

    def test_range(self):
        self.assertEqual(self.spark.range(1, 1).count(), 0)
        self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
        self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
        self.assertEqual(self.spark.range(-2).count(), 0)
        self.assertEqual(self.spark.range(3).count(), 3)

    def test_duplicated_column_names(self):
        df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
        row = df.select('*').first()
        self.assertEqual(1, row[0])
        self.assertEqual(2, row[1])
        self.assertEqual("Row(c=1, c=2)", str(row))
        # Cannot access columns
        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())

    def test_column_name_encoding(self):
        """Ensure that created columns have `str` type consistently."""
        columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns
        self.assertEqual(columns, ['name', 'age'])
        self.assertTrue(isinstance(columns[0], str))
        self.assertTrue(isinstance(columns[1], str))

    def test_explode(self):
        from pyspark.sql.functions import explode, explode_outer, posexplode_outer
        d = [
            Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
            Row(a=1, intlist=[], mapfield={}),
            Row(a=1, intlist=None, mapfield=None),
        ]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        result = data.select(explode(data.intlist).alias("a")).select("a").collect()
        self.assertEqual(result[0][0], 1)
        self.assertEqual(result[1][0], 2)
        self.assertEqual(result[2][0], 3)

        result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
        self.assertEqual(result[0][0], "a")
        self.assertEqual(result[0][1], "b")

        result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
        self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])

        result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
        self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])

        result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
        self.assertEqual(result, [1, 2, 3, None, None])

        result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
        self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])

    def test_and_in_expression(self):
        self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
        self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
        self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count())
        self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2")
        self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count())
        self.assertRaises(ValueError, lambda: not self.df.key == 1)

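    # Note on test_and_in_expression above: Column objects overload `&`, `|` and `~`
    # for boolean logic, while Python's `and`/`or`/`not` need a plain bool and so
    # raise ValueError. Illustrative sketch:
    #
    #     df.filter((df.key <= 10) & (df.value <= "2"))    # OK
    #     df.filter(df.key <= 10 and df.value <= "2")      # raises ValueError
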
    def test_udf_with_callable(self):
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        class PlusFour:
            def __call__(self, col):
                if col is not None:
                    return col + 4

        call = PlusFour()
        pudf = UserDefinedFunction(call, LongType())
        res = data.select(pudf(data['number']).alias('plus_four'))
        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

    def test_udf_with_partial_function(self):
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        def some_func(col, param):
            if col is not None:
                return col + param

        pfunc = functools.partial(some_func, param=4)
        pudf = UserDefinedFunction(pfunc, LongType())
        res = data.select(pudf(data['number']).alias('plus_four'))
        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

    def test_udf(self):
        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
        self.assertEqual(row[0], 5)

        # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
        sqlContext = self.spark._wrapped
        sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
        [row] = sqlContext.sql("SELECT oneArg('test')").collect()
        self.assertEqual(row[0], 4)

    def test_udf2(self):
        self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
        self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
            .createOrReplaceTempView("test")
        [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
        self.assertEqual(4, res[0])

    def test_udf3(self):
        two_args = self.spark.catalog.registerFunction(
            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y))
        self.assertEqual(two_args.deterministic, True)
        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
        # Without an explicit returnType, the registered UDF defaults to string output.
        self.assertEqual(row[0], u'5')

    def test_udf_registration_return_type_none(self):
        two_args = self.spark.catalog.registerFunction(
            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None)
        self.assertEqual(two_args.deterministic, True)
        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
        self.assertEqual(row[0], 5)

    def test_udf_registration_return_type_not_none(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(TypeError, "Invalid returnType"):
                self.spark.catalog.registerFunction(
                    "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType())

    def test_nondeterministic_udf(self):
        # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
        from pyspark.sql.functions import udf
        import random
        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
        self.assertEqual(udf_random_col.deterministic, False)
        df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
        udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
        [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
        self.assertEqual(row[0] + 10, row[1])

    def test_nondeterministic_udf2(self):
        import random
        from pyspark.sql.functions import udf
        random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
        self.assertEqual(random_udf.deterministic, False)
        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
        self.assertEqual(random_udf1.deterministic, False)
        [row] = self.spark.sql("SELECT randInt()").collect()
        self.assertEqual(row[0], 6)
        [row] = self.spark.range(1).select(random_udf1()).collect()
        self.assertEqual(row[0], 6)
        [row] = self.spark.range(1).select(random_udf()).collect()
        self.assertEqual(row[0], 6)
        # render_doc() reproduces the help() exception without printing output
        pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
        pydoc.render_doc(random_udf)
        pydoc.render_doc(random_udf1)
        pydoc.render_doc(udf(lambda x: x).asNondeterministic)

    def test_nondeterministic_udf3(self):
        # regression test for SPARK-23233
        from pyspark.sql.functions import udf
        f = udf(lambda x: x)
        # Here we cache the JVM UDF instance.
        self.spark.range(1).select(f("id"))
        # This should reset the cache to set the deterministic status correctly.
        f = f.asNondeterministic()
        # Check the deterministic status of udf.
        df = self.spark.range(1).select(f("id"))
        deterministic = df._jdf.logicalPlan().projectList().head().deterministic()
        self.assertFalse(deterministic)

    def test_nondeterministic_udf_in_aggregate(self):
        from pyspark.sql.functions import udf, sum
        import random
        udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
        df = self.spark.range(10)

        with QuietTest(self.sc):
            with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
                df.groupby('id').agg(sum(udf_random_col())).collect()
            with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
                df.agg(sum(udf_random_col())).collect()

    def test_chained_udf(self):
        self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
        [row] = self.spark.sql("SELECT double(1)").collect()
        self.assertEqual(row[0], 2)
        [row] = self.spark.sql("SELECT double(double(1))").collect()
        self.assertEqual(row[0], 4)
        [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
        self.assertEqual(row[0], 6)

    def test_single_udf_with_repeated_argument(self):
        # regression test for SPARK-20685
        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
        row = self.spark.sql("SELECT add(1, 1)").first()
        self.assertEqual(tuple(row), (2, ))

def test_multiple_udfs(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
|
|
|
|
[row] = self.spark.sql("SELECT double(1), double(2)").collect()
|
        self.assertEqual(tuple(row), (2, 4))
        [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
        self.assertEqual(tuple(row), (4, 12))
        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
        [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
        self.assertEqual(tuple(row), (6, 5))

    def test_udf_in_filter_on_top_of_outer_join(self):
        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(a=1)])
        df = left.join(right, on='a', how='left_outer')
        df = df.withColumn('b', udf(lambda x: 'x')(df.a))
        self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])

    def test_udf_in_filter_on_top_of_join(self):
        # regression test for SPARK-18589
        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a, b: a == b, BooleanType())
        df = left.crossJoin(right).filter(f("a", "b"))
        self.assertEqual(df.collect(), [Row(a=1, b=1)])

    def test_udf_without_arguments(self):
        self.spark.catalog.registerFunction("foo", lambda: "bar")
        [row] = self.spark.sql("SELECT foo()").collect()
        self.assertEqual(row[0], "bar")

    def test_udf_with_array_type(self):
        d = [Row(l=list(range(3)), d={"key": list(range(5))})]
        rdd = self.sc.parallelize(d)
        self.spark.createDataFrame(rdd).createOrReplaceTempView("test")
        self.spark.catalog.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
        self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType())
        [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect()
        self.assertEqual(list(range(3)), l1)
        self.assertEqual(1, l2)

    def test_broadcast_in_udf(self):
        bar = {"a": "aa", "b": "bb", "c": "abc"}
        foo = self.sc.broadcast(bar)
        self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
        [res] = self.spark.sql("SELECT MYUDF('c')").collect()
        self.assertEqual("abc", res[0])
        [res] = self.spark.sql("SELECT MYUDF('')").collect()
        self.assertEqual("", res[0])
[SPARK-18766][SQL] Push Down Filter Through BatchEvalPython (Python UDF)
### What changes were proposed in this pull request?
Currently, when users use Python UDF in Filter, BatchEvalPython is always generated below FilterExec. However, not all the predicates need to be evaluated after Python UDF execution. Thus, this PR is to push down the determinisitc predicates through `BatchEvalPython`.
```Python
>>> df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
>>> from pyspark.sql.functions import udf, col
>>> from pyspark.sql.types import BooleanType
>>> my_filter = udf(lambda a: a < 2, BooleanType())
>>> sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
>>> sel.explain(True)
```
Before the fix, the plan looks like
```
== Optimized Logical Plan ==
Filter ((isnotnull(value#1) && <lambda>(key#0L)) && (value#1 < 2))
+- LogicalRDD [key#0L, value#1]
== Physical Plan ==
*Project [key#0L, value#1]
+- *Filter ((isnotnull(value#1) && pythonUDF0#9) && (value#1 < 2))
+- BatchEvalPython [<lambda>(key#0L)], [key#0L, value#1, pythonUDF0#9]
+- Scan ExistingRDD[key#0L,value#1]
```
After the fix, the plan looks like
```
== Optimized Logical Plan ==
Filter ((isnotnull(value#1) && <lambda>(key#0L)) && (value#1 < 2))
+- LogicalRDD [key#0L, value#1]
== Physical Plan ==
*Project [key#0L, value#1]
+- *Filter pythonUDF0#9: boolean
+- BatchEvalPython [<lambda>(key#0L)], [key#0L, value#1, pythonUDF0#9]
+- *Filter (isnotnull(value#1) && (value#1 < 2))
+- Scan ExistingRDD[key#0L,value#1]
```
### How was this patch tested?
Added both unit test cases for `BatchEvalPythonExec` and also add an end-to-end test case in Python test suite.
Author: gatorsmile <gatorsmile@gmail.com>
Closes #16193 from gatorsmile/pythonUDFPredicatePushDown.
2016-12-10 11:47:45 -05:00
    def test_udf_with_filter_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql.functions import udf, col
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a < 2, BooleanType())
        sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
        self.assertEqual(sel.collect(), [Row(key=1, value='1')])

    def test_udf_with_aggregate_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
[SPARK-15888] [SQL] fix Python UDF with aggregate
## What changes were proposed in this pull request?
After we move the ExtractPythonUDF rule into physical plan, Python UDF can't work on top of aggregate anymore, because they can't be evaluated before aggregate, should be evaluated after aggregate. This PR add another rule to extract these kind of Python UDF from logical aggregate, create a Project on top of Aggregate.
## How was this patch tested?
Added regression tests. The plan of added test query looks like this:
```
== Parsed Logical Plan ==
'Project [<lambda>('k, 's) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
+- LogicalRDD [key#5L, value#6]
== Analyzed Logical Plan ==
t: int
Project [<lambda>(k#17, s#22L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
+- LogicalRDD [key#5L, value#6]
== Optimized Logical Plan ==
Project [<lambda>(agg#29, agg#30L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS agg#29, sum(cast(<lambda>(value#6) as bigint)) AS agg#30L]
+- LogicalRDD [key#5L, value#6]
== Physical Plan ==
*Project [pythonUDF0#37 AS t#26]
+- BatchEvalPython [<lambda>(agg#29, agg#30L)], [agg#29, agg#30L, pythonUDF0#37]
+- *HashAggregate(key=[<lambda>(key#5L)#31], functions=[sum(cast(<lambda>(value#6) as bigint))], output=[agg#29,agg#30L])
+- Exchange hashpartitioning(<lambda>(key#5L)#31, 200)
+- *HashAggregate(key=[pythonUDF0#34 AS <lambda>(key#5L)#31], functions=[partial_sum(cast(pythonUDF1#35 as bigint))], output=[<lambda>(key#5L)#31,sum#33L])
+- BatchEvalPython [<lambda>(key#5L), <lambda>(value#6)], [key#5L, value#6, pythonUDF0#34, pythonUDF1#35]
+- Scan ExistingRDD[key#5L,value#6]
```
Author: Davies Liu <davies@databricks.com>
Closes #13682 from davies/fix_py_udf.
2016-06-15 16:38:04 -04:00
        from pyspark.sql.functions import udf, col, sum
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a == 1, BooleanType())
        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
        self.assertEqual(sel.collect(), [Row(key=1)])
        my_copy = udf(lambda x: x, IntegerType())
        my_add = udf(lambda a, b: int(a + b), IntegerType())
        my_strlen = udf(lambda x: len(x), IntegerType())
        sel = df.groupBy(my_copy(col("key")).alias("k"))\
            .agg(sum(my_strlen(col("value"))).alias("s"))\
            .select(my_add(col("k"), col("s")).alias("t"))
        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
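As an aside, the query built in the test above is the one whose plan appears in the SPARK-15888 notes earlier; a hedged, standalone sketch (assuming an existing `spark` session) that reproduces that plan interactively could look like:

```python
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
my_copy = F.udf(lambda x: x, IntegerType())
my_add = F.udf(lambda a, b: int(a + b), IntegerType())
my_strlen = F.udf(lambda x: len(x), IntegerType())

sel = (df.groupBy(my_copy(F.col("key")).alias("k"))
         .agg(F.sum(my_strlen(F.col("value"))).alias("s"))
         .select(my_add(F.col("k"), F.col("s")).alias("t")))
# The UDF on top of the aggregate is extracted into a Project above the HashAggregate.
sel.explain(True)
```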
[SPARK-16179][PYSPARK] fix bugs for Python udf in generate
## What changes were proposed in this pull request?
This PR fix the bug when Python UDF is used in explode (generator), GenerateExec requires that all the attributes in expressions should be resolvable from children when creating, we should replace the children first, then replace it's expressions.
```
>>> df.select(explode(f(*df))).show()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/vlad/dev/spark/python/pyspark/sql/dataframe.py", line 286, in show
print(self._jdf.showString(n, truncate))
File "/home/vlad/dev/spark/python/lib/py4j-0.10.1-src.zip/py4j/java_gateway.py", line 933, in __call__
File "/home/vlad/dev/spark/python/pyspark/sql/utils.py", line 63, in deco
return f(*a, **kw)
File "/home/vlad/dev/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 312, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o52.showString.
: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: makeCopy, tree:
Generate explode(<lambda>(_1#0L)), false, false, [col#15L]
+- Scan ExistingRDD[_1#0L]
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:50)
at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:387)
at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:69)
at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:45)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsDown(QueryPlan.scala:177)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressions(QueryPlan.scala:144)
at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract(ExtractPythonUDFs.scala:153)
at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:114)
at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:113)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:301)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:301)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:300)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:298)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:298)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:321)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:179)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:319)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:298)
at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:113)
at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:93)
at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:95)
at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:95)
at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124)
at scala.collection.immutable.List.foldLeft(List.scala:84)
at org.apache.spark.sql.execution.QueryExecution.prepareForExecution(QueryExecution.scala:95)
at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:85)
at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:85)
at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2557)
at org.apache.spark.sql.Dataset.head(Dataset.scala:1923)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2138)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:239)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:237)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:280)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:211)
at java.lang.Thread.run(Thread.java:745)
Caused by: java.lang.reflect.InvocationTargetException
at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)
at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)
at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)
at java.lang.reflect.Constructor.newInstance(Constructor.java:423)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1$$anonfun$apply$13.apply(TreeNode.scala:413)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1$$anonfun$apply$13.apply(TreeNode.scala:413)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1.apply(TreeNode.scala:412)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1.apply(TreeNode.scala:387)
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:49)
... 42 more
Caused by: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF0#20
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:50)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:88)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:87)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:279)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:279)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:278)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:284)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:284)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:321)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:179)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:319)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:284)
at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:268)
at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:87)
at org.apache.spark.sql.execution.GenerateExec.<init>(GenerateExec.scala:63)
... 52 more
Caused by: java.lang.RuntimeException: Couldn't find pythonUDF0#20 in [_1#0L]
at scala.sys.package$.error(package.scala:27)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:94)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:88)
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:49)
... 67 more
```
## How was this patch tested?
Added regression tests.
Author: Davies Liu <davies@databricks.com>
Closes #13883 from davies/udf_in_generate.
2016-06-24 18:20:39 -04:00
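The traceback above references `f` and `df` without defining them; a minimal, self-contained reproduction (assuming a local `spark` session; the names mirror the test that follows) is roughly:

```python
from pyspark.sql.functions import explode, udf
from pyspark.sql.types import ArrayType, LongType

df = spark.range(5)
# A Python UDF producing an array, fed directly into the explode generator.
f = udf(lambda x: list(range(x)), ArrayType(LongType()))
df.select(explode(f(*df))).show()  # raised the TreeNodeException before the fix
```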
    def test_udf_in_generate(self):
        # regression test for SPARK-16179: a Python UDF feeding explode (a generator)
        from pyspark.sql.functions import udf, explode
        df = self.spark.range(5)
        f = udf(lambda x: list(range(x)), ArrayType(LongType()))
        row = df.select(explode(f(*df))).groupBy().sum().first()
        self.assertEqual(row[0], 10)
[SPARK-18634][PYSPARK][SQL] Corruption and Correctness issues with exploding Python UDFs
## What changes were proposed in this pull request?
As reported in the Jira, there are some weird issues with exploding Python UDFs in SparkSQL.
The following test code can reproduce it. Notice: the following test code is reported to return wrong results in the Jira. However, as I tested on master branch, it causes exception and so can't return any result.
>>> from pyspark.sql.functions import *
>>> from pyspark.sql.types import *
>>>
>>> df = spark.range(10)
>>>
>>> def return_range(value):
... return [(i, str(i)) for i in range(value - 1, value + 1)]
...
>>> range_udf = udf(return_range, ArrayType(StructType([StructField("integer_val", IntegerType()),
... StructField("string_val", StringType())])))
>>>
>>> df.select("id", explode(range_udf(df.id))).show()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/spark/python/pyspark/sql/dataframe.py", line 318, in show
print(self._jdf.showString(n, 20))
File "/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py", line 1133, in __call__
File "/spark/python/pyspark/sql/utils.py", line 63, in deco
return f(*a, **kw)
File "/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py", line 319, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o126.showString.: java.lang.AssertionError: assertion failed
at scala.Predef$.assert(Predef.scala:156)
at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:120)
at org.apache.spark.sql.execution.GenerateExec.consume(GenerateExec.scala:57)
The cause of this issue is, in `ExtractPythonUDFs` we insert `BatchEvalPythonExec` to run PythonUDFs in batch. `BatchEvalPythonExec` will add extra outputs (e.g., `pythonUDF0`) to original plan. In above case, the original `Range` only has one output `id`. After `ExtractPythonUDFs`, the added `BatchEvalPythonExec` has two outputs `id` and `pythonUDF0`.
Because the output of `GenerateExec` is given after analysis phase, in above case, it is the combination of `id`, i.e., the output of `Range`, and `col`. But in planning phase, we change `GenerateExec`'s child plan to `BatchEvalPythonExec` with additional output attributes.
It will cause no problem in non wholestage codegen. Because when evaluating the additional attributes are projected out the final output of `GenerateExec`.
However, as `GenerateExec` now supports wholestage codegen, the framework will input all the outputs of the child plan to `GenerateExec`. Then when consuming `GenerateExec`'s output data (i.e., calling `consume`), the number of output attributes is different to the output variables in wholestage codegen.
To solve this issue, this patch only gives the generator's output to `GenerateExec` after analysis phase. `GenerateExec`'s output is the combination of its child plan's output and the generator's output. So when we change `GenerateExec`'s child, its output is still correct.
## How was this patch tested?
Added test cases to PySpark.
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #16120 from viirya/fix-py-udf-with-generator.
2016-12-05 20:50:43 -05:00
        df = self.spark.range(3)
        res = df.select("id", explode(f(df.id))).collect()
        self.assertEqual(res[0][0], 1)
        self.assertEqual(res[0][1], 0)
        self.assertEqual(res[1][0], 2)
        self.assertEqual(res[1][1], 0)
        self.assertEqual(res[2][0], 2)
        self.assertEqual(res[2][1], 1)

        range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
        res = df.select("id", explode(range_udf(df.id))).collect()
        self.assertEqual(res[0][0], 0)
        self.assertEqual(res[0][1], -1)
        self.assertEqual(res[1][0], 0)
        self.assertEqual(res[1][1], 0)
        self.assertEqual(res[2][0], 1)
        self.assertEqual(res[2][1], 0)
        self.assertEqual(res[3][0], 1)
        self.assertEqual(res[3][1], 1)

    def test_udf_with_order_by_and_limit(self):
        from pyspark.sql.functions import udf
        my_copy = udf(lambda x: x, IntegerType())
        df = self.spark.range(10).orderBy("id")
        res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
        res.explain(True)
        self.assertEqual(res.collect(), [Row(id=0, copy=0)])

    def test_udf_registration_returns_udf(self):
        df = self.spark.range(10)
        add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())

        self.assertListEqual(
            df.selectExpr("add_three(id) AS plus_three").collect(),
            df.select(add_three("id").alias("plus_three")).collect()
        )

        # This is to check if a 'SQLContext.udf' can call its alias.
        sqlContext = self.spark._wrapped
        add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())

        self.assertListEqual(
            df.selectExpr("add_four(id) AS plus_four").collect(),
            df.select(add_four("id").alias("plus_four")).collect()
        )

    def test_non_existed_udf(self):
        spark = self.spark
        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
                                lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))

        # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
        sqlContext = spark._wrapped
        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
                                lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))

    def test_non_existed_udaf(self):
        spark = self.spark
        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
                                lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))

    def test_multiLine_json(self):
        people1 = self.spark.read.json("python/test_support/sql/people.json")
        people_array = self.spark.read.json("python/test_support/sql/people_array.json",
                                            multiLine=True)
        self.assertEqual(people1.collect(), people_array.collect())

    def test_multiline_csv(self):
        ages_newlines = self.spark.read.csv(
            "python/test_support/sql/ages_newlines.csv", multiLine=True)
        expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'),
                    Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'),
                    Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
        self.assertEqual(ages_newlines.collect(), expected)
[SPARK-18579][SQL] Use ignoreLeadingWhiteSpace and ignoreTrailingWhiteSpace options in CSV writing
## What changes were proposed in this pull request?
This PR proposes to support _not_ trimming the white spaces when writing out. These are `false` by default in CSV reading path but these are `true` by default in CSV writing in univocity parser.
Both `ignoreLeadingWhiteSpace` and `ignoreTrailingWhiteSpace` options are not being used for writing and therefore, we are always trimming the white spaces.
It seems we should provide a way to keep this white spaces easily.
WIth the data below:
```scala
val df = spark.read.csv(Seq("a , b , c").toDS)
df.show()
```
```
+---+----+---+
|_c0| _c1|_c2|
+---+----+---+
| a | b | c|
+---+----+---+
```
**Before**
```scala
df.write.csv("/tmp/text.csv")
spark.read.text("/tmp/text.csv").show()
```
```
+-----+
|value|
+-----+
|a,b,c|
+-----+
```
It seems this can't be worked around via `quoteAll` too.
```scala
df.write.option("quoteAll", true).csv("/tmp/text.csv")
spark.read.text("/tmp/text.csv").show()
```
```
+-----------+
| value|
+-----------+
|"a","b","c"|
+-----------+
```
**After**
```scala
df.write.option("ignoreLeadingWhiteSpace", false).option("ignoreTrailingWhiteSpace", false).csv("/tmp/text.csv")
spark.read.text("/tmp/text.csv").show()
```
```
+----------+
| value|
+----------+
|a , b , c|
+----------+
```
Note that this case is possible in R
```r
> system("cat text.csv")
f1,f2,f3
a , b , c
> df <- read.csv(file="text.csv")
> df
f1 f2 f3
1 a b c
> write.csv(df, file="text1.csv", quote=F, row.names=F)
> system("cat text1.csv")
f1,f2,f3
a , b , c
```
## How was this patch tested?
Unit tests in `CSVSuite` and manual tests for Python.
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #17310 from HyukjinKwon/SPARK-18579.
2017-03-23 03:25:01 -04:00
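The Scala and R snippets above have a direct PySpark counterpart; a rough sketch using the option-style API, assuming an existing `spark` session (the output path is illustrative only):

```python
df = spark.createDataFrame([(" a", "b ", " c ")])
(df.write
   .option("ignoreLeadingWhiteSpace", False)
   .option("ignoreTrailingWhiteSpace", False)
   .csv("/tmp/text.csv"))
spark.read.text("/tmp/text.csv").show()
```

The test that follows exercises the same options as keyword arguments to `DataFrameWriter.csv`.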
    def test_ignorewhitespace_csv(self):
        # SPARK-18579: keep leading/trailing whitespace when writing CSV
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv(
            tmpPath,
            ignoreLeadingWhiteSpace=False,
            ignoreTrailingWhiteSpace=False)

        expected = [Row(value=u' a,b , c ')]
        readback = self.spark.read.text(tmpPath)
        self.assertEqual(readback.collect(), expected)
        shutil.rmtree(tmpPath)

    def test_read_multiple_orc_file(self):
        df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0",
                                  "python/test_support/sql/orc_partitioned/b=1/c=1"])
        self.assertEqual(2, df.count())

    def test_udf_with_input_file_name(self):
        from pyspark.sql.functions import udf, input_file_name
        sourceFile = udf(lambda path: path, StringType())
        filePath = "python/test_support/sql/people1.json"
        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
        self.assertTrue(row[0].find("people1.json") != -1)
[SPARK-19223][SQL][PYSPARK] Fix InputFileBlockHolder for datasources which are based on HadoopRDD or NewHadoopRDD
## What changes were proposed in this pull request?
For some datasources which are based on HadoopRDD or NewHadoopRDD, such as spark-xml, InputFileBlockHolder doesn't work with Python UDF.
The method to reproduce it is, running the following codes with `bin/pyspark --packages com.databricks:spark-xml_2.11:0.4.1`:
from pyspark.sql.functions import udf,input_file_name
from pyspark.sql.types import StringType
from pyspark.sql import SparkSession
def filename(path):
return path
session = SparkSession.builder.appName('APP').getOrCreate()
session.udf.register('sameText', filename)
sameText = udf(filename, StringType())
df = session.read.format('xml').load('a.xml', rowTag='root').select('*', input_file_name().alias('file'))
df.select('file').show() # works
df.select(sameText(df['file'])).show() # returns empty content
The issue is because in `HadoopRDD` and `NewHadoopRDD` we set the file block's info in `InputFileBlockHolder` before the returned iterator begins consuming. `InputFileBlockHolder` will record this info into thread local variable. When running Python UDF in batch, we set up another thread to consume the iterator from child plan's output rdd, so we can't read the info back in another thread.
To fix this, we have to set the info in `InputFileBlockHolder` after the iterator begins consuming. So the info can be read in correct thread.
## How was this patch tested?
Manual test with above example codes for spark-xml package on pyspark: `bin/pyspark --packages com.databricks:spark-xml_2.11:0.4.1`.
Added pyspark test.
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #16585 from viirya/fix-inputfileblock-hadooprdd.
2017-01-18 10:06:44 -05:00
    def test_udf_with_input_file_name_for_hadooprdd(self):
        from pyspark.sql.functions import udf, input_file_name

        def filename(path):
            return path

        sameText = udf(filename, StringType())

        rdd = self.sc.textFile('python/test_support/sql/people.json')
        df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
        row = df.select(sameText(df['file'])).first()
        self.assertTrue(row[0].find("people.json") != -1)

        rdd2 = self.sc.newAPIHadoopFile(
            'python/test_support/sql/people.json',
            'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
            'org.apache.hadoop.io.LongWritable',
            'org.apache.hadoop.io.Text')

        df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
        row2 = df2.select(sameText(df2['file'])).first()
        self.assertTrue(row2[0].find("people.json") != -1)

    def test_udf_defers_judf_initalization(self):
        # This is kept separate from UDFInitializationTests
        # to avoid context initialization
        # when udf is called

        from pyspark.sql.functions import UserDefinedFunction

        f = UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(
            f._judf_placeholder,
            "judf should not be initialized before the first call."
        )

        self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")

        self.assertIsNotNone(
            f._judf_placeholder,
            "judf should be initialized after UDF has been called."
        )

    def test_udf_with_string_return_type(self):
        from pyspark.sql.functions import UserDefinedFunction

        add_one = UserDefinedFunction(lambda x: x + 1, "integer")
        make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
        make_array = UserDefinedFunction(
            lambda x: [float(x) for x in range(x, x + 3)], "array<double>")

        expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
        actual = (self.spark.range(1, 2).toDF("x")
                  .select(add_one("x"), make_pair("x"), make_array("x"))
                  .first())

        self.assertTupleEqual(expected, actual)

    def test_udf_shouldnt_accept_noncallable_object(self):
        from pyspark.sql.functions import UserDefinedFunction

        non_callable = None
        self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())

    def test_udf_with_decorator(self):
        from pyspark.sql.functions import lit, udf
        from pyspark.sql.types import IntegerType, DoubleType

        @udf(IntegerType())
        def add_one(x):
            if x is not None:
                return x + 1

        @udf(returnType=DoubleType())
        def add_two(x):
            if x is not None:
                return float(x + 2)

        @udf
        def to_upper(x):
            if x is not None:
                return x.upper()

        @udf()
        def to_lower(x):
            if x is not None:
                return x.lower()

        @udf
        def substr(x, start, end):
            if x is not None:
                return x[start:end]

        @udf("long")
        def trunc(x):
            return int(x)

        @udf(returnType="double")
        def as_double(x):
            return float(x)

        df = (
            self.spark
                .createDataFrame(
                    [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
                .select(
                    add_one("one"), add_two("one"),
                    to_upper("Foo"), to_lower("Foo"),
                    substr("foobar", lit(0), lit(3)),
                    trunc("float"), as_double("one")))

        self.assertListEqual(
            [tpe for _, tpe in df.dtypes],
            ["int", "double", "string", "string", "string", "bigint", "double"]
        )

        self.assertListEqual(
            list(df.first()),
            [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
        )

    def test_udf_wrapper(self):
        from pyspark.sql.functions import udf
        from pyspark.sql.types import IntegerType

        def f(x):
            """Identity"""
            return x

        return_type = IntegerType()
        f_ = udf(f, return_type)

        self.assertTrue(f.__doc__ in f_.__doc__)
        self.assertEqual(f, f_.func)
        self.assertEqual(return_type, f_.returnType)
[SPARK-21394][SPARK-21432][PYTHON] Reviving callable object/partial function support in UDF in PySpark
## What changes were proposed in this pull request?
This PR proposes to avoid `__name__` in the tuple naming the attributes assigned directly from the wrapped function to the wrapper function, and use `self._name` (`func.__name__` or `obj.__class__.name__`).
After SPARK-19161, we happened to break callable objects as UDFs in Python as below:
```python
from pyspark.sql import functions
class F(object):
def __call__(self, x):
return x
foo = F()
udf = functions.udf(foo)
```
```
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark/python/pyspark/sql/functions.py", line 2142, in udf
return _udf(f=f, returnType=returnType)
File ".../spark/python/pyspark/sql/functions.py", line 2133, in _udf
return udf_obj._wrapped()
File ".../spark/python/pyspark/sql/functions.py", line 2090, in _wrapped
functools.wraps(self.func)
File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/functools.py", line 33, in update_wrapper
setattr(wrapper, attr, getattr(wrapped, attr))
AttributeError: F instance has no attribute '__name__'
```
This worked in Spark 2.1:
```python
from pyspark.sql import functions
class F(object):
def __call__(self, x):
return x
foo = F()
udf = functions.udf(foo)
spark.range(1).select(udf("id")).show()
```
```
+-----+
|F(id)|
+-----+
| 0|
+-----+
```
**After**
```python
from pyspark.sql import functions
class F(object):
def __call__(self, x):
return x
foo = F()
udf = functions.udf(foo)
spark.range(1).select(udf("id")).show()
```
```
+-----+
|F(id)|
+-----+
| 0|
+-----+
```
_In addition, we also happened to break partial functions as below_:
```python
from pyspark.sql import functions
from functools import partial
partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
```
```
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark/python/pyspark/sql/functions.py", line 2154, in udf
return _udf(f=f, returnType=returnType)
File ".../spark/python/pyspark/sql/functions.py", line 2145, in _udf
return udf_obj._wrapped()
File ".../spark/python/pyspark/sql/functions.py", line 2099, in _wrapped
functools.wraps(self.func, assigned=assignments)
File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/functools.py", line 33, in update_wrapper
setattr(wrapper, attr, getattr(wrapped, attr))
AttributeError: 'functools.partial' object has no attribute '__module__'
```
This worked in Spark 2.1:
```python
from pyspark.sql import functions
from functools import partial
partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
spark.range(1).select(udf()).show()
```
```
+---------+
|partial()|
+---------+
| 1|
+---------+
```
**After**
```python
from pyspark.sql import functions
from functools import partial
partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
spark.range(1).select(udf()).show()
```
```
+---------+
|partial()|
+---------+
| 1|
+---------+
```
## How was this patch tested?
Unit tests in `python/pyspark/sql/tests.py` and manual tests.
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #18615 from HyukjinKwon/callable-object.
2017-07-17 03:37:36 -04:00
        class F(object):
            """Identity"""
            def __call__(self, x):
                return x

        f = F()
        return_type = IntegerType()
        f_ = udf(f, return_type)

        self.assertTrue(f.__doc__ in f_.__doc__)
        self.assertEqual(f, f_.func)
        self.assertEqual(return_type, f_.returnType)

        f = functools.partial(f, x=1)
        return_type = IntegerType()
        f_ = udf(f, return_type)

        self.assertTrue(f.__doc__ in f_.__doc__)
        self.assertEqual(f, f_.func)
        self.assertEqual(return_type, f_.returnType)
[SPARK-19165][PYTHON][SQL] PySpark APIs using columns as arguments should validate input types for column
## What changes were proposed in this pull request?
While preparing to take over https://github.com/apache/spark/pull/16537, I realised a (I think) better approach to make the exception handling in one point.
This PR proposes to fix `_to_java_column` in `pyspark.sql.column`, which most of functions in `functions.py` and some other APIs use. This `_to_java_column` basically looks not working with other types than `pyspark.sql.column.Column` or string (`str` and `unicode`).
If this is not `Column`, then it calls `_create_column_from_name` which calls `functions.col` within JVM:
https://github.com/apache/spark/blob/42b9eda80e975d970c3e8da4047b318b83dd269f/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L76
And it looks we only have `String` one with `col`.
So, these should work:
```python
>>> from pyspark.sql.column import _to_java_column, Column
>>> _to_java_column("a")
JavaObject id=o28
>>> _to_java_column(u"a")
JavaObject id=o29
>>> _to_java_column(spark.range(1).id)
JavaObject id=o33
```
whereas these do not:
```python
>>> _to_java_column(1)
```
```
...
py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace:
py4j.Py4JException: Method col([class java.lang.Integer]) does not exist
...
```
```python
>>> _to_java_column([])
```
```
...
py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace:
py4j.Py4JException: Method col([class java.util.ArrayList]) does not exist
...
```
```python
>>> class A(): pass
>>> _to_java_column(A())
```
```
...
AttributeError: 'A' object has no attribute '_get_object_id'
```
Meaning most of functions using `_to_java_column` such as `udf` or `to_json` or some other APIs throw an exception as below:
```python
>>> from pyspark.sql.functions import udf
>>> udf(lambda x: x)(None)
```
```
...
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.functions.col.
: java.lang.NullPointerException
...
```
```python
>>> from pyspark.sql.functions import to_json
>>> to_json(None)
```
```
...
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.functions.col.
: java.lang.NullPointerException
...
```
**After this PR**:
```python
>>> from pyspark.sql.functions import udf
>>> udf(lambda x: x)(None)
...
```
```
TypeError: Invalid argument, not a string or column: None of type <type 'NoneType'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' functions.
```
```python
>>> from pyspark.sql.functions import to_json
>>> to_json(None)
```
```
...
TypeError: Invalid argument, not a string or column: None of type <type 'NoneType'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' functions.
```
## How was this patch tested?
Unit tests added in `python/pyspark/sql/tests.py` and manual tests.
Author: hyukjinkwon <gurwls223@gmail.com>
Author: zero323 <zero323@users.noreply.github.com>
Closes #19027 from HyukjinKwon/SPARK-19165.
2017-08-24 07:29:03 -04:00
    def test_validate_column_types(self):
        from pyspark.sql.functions import udf, to_json
        from pyspark.sql.column import _to_java_column

        self.assertTrue("Column" in _to_java_column("a").getClass().toString())
        self.assertTrue("Column" in _to_java_column(u"a").getClass().toString())
        self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString())

        self.assertRaisesRegexp(
            TypeError,
            "Invalid argument, not a string or column",
            lambda: _to_java_column(1))

        class A():
            pass

        self.assertRaises(TypeError, lambda: _to_java_column(A()))
        self.assertRaises(TypeError, lambda: _to_java_column([]))

        self.assertRaisesRegexp(
            TypeError,
            "Invalid argument, not a string or column",
            lambda: udf(lambda x: x)(None))
        self.assertRaises(TypeError, lambda: to_json(1))
    def test_basic_functions(self):
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)
        df.count()
        df.collect()
        df.schema

        # cache and checkpoint
        self.assertFalse(df.is_cached)
        df.persist()
        df.unpersist(True)
        df.cache()
        self.assertTrue(df.is_cached)
        self.assertEqual(2, df.count())

        df.createOrReplaceTempView("temp")
        df = self.spark.sql("select foo from temp")
        df.count()
        df.collect()

    def test_apply_schema_to_row(self):
        df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
        df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
        self.assertEqual(df.collect(), df2.collect())

        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
        df3 = self.spark.createDataFrame(rdd, df.schema)
        self.assertEqual(10, df3.count())

    def test_infer_schema_to_local(self):
        input = [{"a": 1}, {"b": "coffee"}]
        rdd = self.sc.parallelize(input)
        df = self.spark.createDataFrame(input)
        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
        self.assertEqual(df.schema, df2.schema)

        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
        df3 = self.spark.createDataFrame(rdd, df.schema)
        self.assertEqual(10, df3.count())

    def test_apply_schema_to_dict_and_rows(self):
        schema = StructType().add("b", StringType()).add("a", IntegerType())
        input = [{"a": 1}, {"b": "coffee"}]
        rdd = self.sc.parallelize(input)
        for verify in [False, True]:
            df = self.spark.createDataFrame(input, schema, verifySchema=verify)
            df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
            self.assertEqual(df.schema, df2.schema)

            rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
            df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
            self.assertEqual(10, df3.count())
            input = [Row(a=x, b=str(x)) for x in range(10)]
            df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
            self.assertEqual(10, df4.count())
    def test_create_dataframe_schema_mismatch(self):
        input = [Row(a=1)]
        rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
        schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
        df = self.spark.createDataFrame(rdd, schema)
        self.assertRaises(Exception, lambda: df.show())

    def test_serialize_nested_array_and_map(self):
        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
        rdd = self.sc.parallelize(d)
        df = self.spark.createDataFrame(rdd)
        row = df.head()
        self.assertEqual(1, len(row.l))
        self.assertEqual(1, row.l[0].a)
        self.assertEqual("2", row.d["key"].d)

        l = df.rdd.map(lambda x: x.l).first()
        self.assertEqual(1, len(l))
        self.assertEqual('s', l[0].b)

        d = df.rdd.map(lambda x: x.d).first()
        self.assertEqual(1, len(d))
        self.assertEqual(1.0, d["key"].c)

        row = df.rdd.map(lambda x: x.d["key"]).first()
        self.assertEqual(1.0, row.c)
        self.assertEqual("2", row.d)

    def test_infer_schema(self):
        d = [Row(l=[], d={}, s=None),
             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
        rdd = self.sc.parallelize(d)
        df = self.spark.createDataFrame(rdd)
        self.assertEqual([], df.rdd.map(lambda r: r.l).first())
        self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
        df.createOrReplaceTempView("test")
        result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'")
        self.assertEqual(1, result.head()[0])

        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
        self.assertEqual(df.schema, df2.schema)
        self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
        self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
        df2.createOrReplaceTempView("test2")
        result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
        self.assertEqual(1, result.head()[0])

    def test_infer_schema_not_enough_names(self):
        df = self.spark.createDataFrame([["a", "b"]], ["col1"])
        self.assertEqual(df.columns, ['col1', '_2'])

    def test_infer_schema_fails(self):
        with self.assertRaisesRegexp(TypeError, 'field a'):
            self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
                                       schema=["a", "b"], samplingRatio=0.99)
def test_infer_nested_schema(self):
|
|
|
|
NestedRow = Row("f1", "f2")
|
|
|
|
nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
|
|
|
|
NestedRow([2, 3], {"row2": 2.0})])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(nestedRdd1)
|
2015-02-24 23:51:55 -05:00
|
|
|
self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
|
|
|
|
|
|
|
|
nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
|
|
|
|
NestedRow([[2, 3], [3, 4]], [2, 3])])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(nestedRdd2)
|
2015-02-24 23:51:55 -05:00
|
|
|
self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
|
|
|
|
|
|
|
|
from collections import namedtuple
|
|
|
|
CustomRow = namedtuple('CustomRow', 'field1 field2')
|
|
|
|
rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
|
|
|
|
CustomRow(field1=2, field2="row2"),
|
|
|
|
CustomRow(field1=3, field2="row3")])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(rdd)
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
|
2015-02-24 23:51:55 -05:00
|
|
|
|
2018-01-08 00:32:05 -05:00
|
|
|
def test_create_dataframe_from_dict_respects_schema(self):
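# An explicitly passed column-name list should take precedence over the dict keys.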
|
|
|
|
df = self.spark.createDataFrame([{'a': 1}], ["b"])
|
|
|
|
self.assertEqual(df.columns, ['b'])
|
|
|
|
|
2015-08-26 19:04:44 -04:00
|
|
|
def test_create_dataframe_from_objects(self):
|
|
|
|
data = [MyObject(1, "1"), MyObject(2, "2")]
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(data)
|
2015-08-26 19:04:44 -04:00
|
|
|
self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
|
|
|
|
self.assertEqual(df.first(), Row(key=1, value="1"))
|
|
|
|
|
2015-07-20 15:00:48 -04:00
|
|
|
def test_select_null_literal(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.sql("select null as col")
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(Row(col=None), df.first())
|
2015-07-20 15:00:48 -04:00
|
|
|
|
2015-02-24 23:51:55 -05:00
|
|
|
def test_apply_schema(self):
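# Apply an explicit StructType with numeric, date/time, map, struct, array and nullable
# columns, then query the resulting temp view through SQL.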
|
|
|
|
from datetime import date, datetime
|
2015-04-16 19:20:57 -04:00
|
|
|
rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
|
2015-02-24 23:51:55 -05:00
|
|
|
date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
|
|
|
|
{"a": 1}, (2,), [1, 2, 3], None)])
|
|
|
|
schema = StructType([
|
|
|
|
StructField("byte1", ByteType(), False),
|
|
|
|
StructField("byte2", ByteType(), False),
|
|
|
|
StructField("short1", ShortType(), False),
|
|
|
|
StructField("short2", ShortType(), False),
|
|
|
|
StructField("int1", IntegerType(), False),
|
|
|
|
StructField("float1", FloatType(), False),
|
|
|
|
StructField("date1", DateType(), False),
|
|
|
|
StructField("time1", TimestampType(), False),
|
|
|
|
StructField("map1", MapType(StringType(), IntegerType(), False), False),
|
|
|
|
StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
|
|
|
|
StructField("list1", ArrayType(ByteType(), False), False),
|
|
|
|
StructField("null1", DoubleType(), True)])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(rdd, schema)
|
2016-03-02 18:26:34 -05:00
|
|
|
results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
|
|
|
|
x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
|
2015-02-24 23:51:55 -05:00
|
|
|
r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
|
|
|
|
datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
|
|
|
|
self.assertEqual(r, results.first())
|
|
|
|
|
2016-05-17 21:01:59 -04:00
|
|
|
df.createOrReplaceTempView("table2")
|
2016-05-11 14:24:16 -04:00
|
|
|
r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
|
|
|
|
"short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
|
|
|
|
"float1 + 1.5 as float1 FROM table2").first()
|
2015-02-24 23:51:55 -05:00
|
|
|
|
|
|
|
self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_struct_in_map(self):
|
|
|
|
d = [Row(m={Row(i=1): Row(s="")})]
|
2015-02-14 02:03:22 -05:00
|
|
|
df = self.sc.parallelize(d).toDF()
|
2015-04-16 19:20:57 -04:00
|
|
|
k, v = list(df.head().m.items())[0]
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, k.i)
|
|
|
|
self.assertEqual("", v.s)
|
|
|
|
|
|
|
|
def test_convert_row_to_dict(self):
|
|
|
|
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
|
|
|
|
self.assertEqual(1, row.asDict()['l'][0].a)
|
2015-02-14 02:03:22 -05:00
|
|
|
df = self.sc.parallelize([row]).toDF()
|
2016-05-17 21:01:59 -04:00
|
|
|
df.createOrReplaceTempView("test")
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.sql("select l, d from test").head()
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, row.asDict()["l"][0].a)
|
|
|
|
self.assertEqual(1.0, row.asDict()['d']['key'].c)
|
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
def test_udt(self):
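# UDTs (ExamplePointUDT and PythonOnlyUDT) should survive pickle and JSON round-trips
# through the JVM parser, be inferable from instances, and be enforced by the type verifier.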
|
[SPARK-19507][SPARK-21296][PYTHON] Avoid per-record type dispatch in schema verification and improve exception message
## What changes were proposed in this pull request?
**Context**
While reviewing https://github.com/apache/spark/pull/17227, I realised that here we type-dispatch per record. The PR itself is fine in terms of performance as is, but it prints a prefix, `"obj"`, in the exception message as below:
```
from pyspark.sql.types import *
schema = StructType([StructField('s', IntegerType(), nullable=False)])
spark.createDataFrame([["1"]], schema)
...
TypeError: obj.s: IntegerType can not accept object '1' in type <type 'str'>
```
I suggested getting rid of this, but while investigating it I realised my approach might bring a performance regression, as this is a hot path.
For SPARK-19507 and https://github.com/apache/spark/pull/17227 alone, more changes are needed to cleanly get rid of the prefix, so I decided to fix both issues together.
**Proposal**
This PR tried to
- get rid of per-record type dispatch as we do in many code paths in Scala so that it improves the performance (roughly ~25% improvement) - SPARK-21296
This was tested with a simple code `spark.createDataFrame(range(1000000), "int")`. However, I am quite sure the actual improvement in practice is larger than this, in particular, when the schema is complicated.
- improve error message in exception describing field information as prose - SPARK-19507
## How was this patch tested?
Manually tested and unit tests were added in `python/pyspark/sql/tests.py`.
Benchmark - codes: https://gist.github.com/HyukjinKwon/c3397469c56cb26c2d7dd521ed0bc5a3
Error message - codes: https://gist.github.com/HyukjinKwon/b1b2c7f65865444c4a8836435100e398
**Before**
Benchmark:
- Results: https://gist.github.com/HyukjinKwon/4a291dab45542106301a0c1abcdca924
Error message
- Results: https://gist.github.com/HyukjinKwon/57b1916395794ce924faa32b14a3fe19
**After**
Benchmark
- Results: https://gist.github.com/HyukjinKwon/21496feecc4a920e50c4e455f836266e
Error message
- Results: https://gist.github.com/HyukjinKwon/7a494e4557fe32a652ce1236e504a395
Closes #17227
Author: hyukjinkwon <gurwls223@gmail.com>
Author: David Gingrich <david@textio.com>
Closes #18521 from HyukjinKwon/python-type-dispatch.
2017-07-04 08:45:58 -04:00
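For reference, a minimal sketch of the post-change behaviour (assuming a Spark build that ships the internal `_make_type_verifier` helper; the exact error wording below is illustrative, not guaranteed):
```python
from pyspark.sql.types import StructType, StructField, IntegerType, _make_type_verifier

schema = StructType([StructField("s", IntegerType(), nullable=False)])
verify = _make_type_verifier(schema)
verify((1,))  # an int is acceptable for the non-nullable IntegerType field "s"
try:
    verify(("1",))  # a str is not
except TypeError as e:
    # The message now describes the field as prose instead of using an "obj" prefix,
    # e.g. "field s: IntegerType can not accept object '1' in type <class 'str'>".
    print(e)
```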
|
|
|
from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
|
2015-07-30 01:30:49 -04:00
|
|
|
from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
|
|
|
|
|
|
|
|
def check_datatype(datatype):
|
|
|
|
pickled = pickle.loads(pickle.dumps(datatype))
|
|
|
|
assert datatype == pickled
|
2016-08-25 02:36:04 -04:00
|
|
|
scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
|
2015-07-30 01:30:49 -04:00
|
|
|
python_datatype = _parse_datatype_json_string(scala_datatype.json())
|
|
|
|
assert datatype == python_datatype
|
|
|
|
|
|
|
|
check_datatype(ExamplePointUDT())
|
|
|
|
structtype_with_udt = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", ExamplePointUDT(), False)])
|
|
|
|
check_datatype(structtype_with_udt)
|
|
|
|
p = ExamplePoint(1.0, 2.0)
|
|
|
|
self.assertEqual(_infer_type(p), ExamplePointUDT())
|
2017-07-04 08:45:58 -04:00
|
|
|
_make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
|
|
|
|
self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))
|
2015-07-30 01:30:49 -04:00
|
|
|
|
|
|
|
check_datatype(PythonOnlyUDT())
|
|
|
|
structtype_with_udt = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", PythonOnlyUDT(), False)])
|
|
|
|
check_datatype(structtype_with_udt)
|
|
|
|
p = PythonOnlyPoint(1.0, 2.0)
|
|
|
|
self.assertEqual(_infer_type(p), PythonOnlyUDT())
|
2017-07-04 08:45:58 -04:00
|
|
|
_make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
|
|
|
|
self.assertRaises(
|
|
|
|
ValueError,
|
|
|
|
lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
|
2015-07-30 01:30:49 -04:00
|
|
|
|
[SPARK-16062] [SPARK-15989] [SQL] Fix two bugs of Python-only UDTs
## What changes were proposed in this pull request?
There are two related bugs of Python-only UDTs. Because the test case for the second one needs the first fix too, I put them into one PR. If that is not appropriate, please let me know.
### First bug: When MapObjects works on Python-only UDTs
`RowEncoder` will use `PythonUserDefinedType.sqlType` for its deserializer expression. If the sql type is `ArrayType`, we will have `MapObjects` working on it. But `MapObjects` doesn't consider `PythonUserDefinedType` as its input data type. It causes an error like:
```
import pyspark.sql.group
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql.types import *
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
df = spark.createDataFrame([(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema)
df.show()

File "/home/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 312, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o36.showString.
: java.lang.RuntimeException: Error while decoding: scala.MatchError: org.apache.spark.sql.types.PythonUserDefinedType@f4ceede8 (of class org.apache.spark.sql.types.PythonUserDefinedType)
...
```
### Second bug: When a Python-only UDT is the element type of ArrayType
```
import pyspark.sql.group
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql.types import *
schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
df = spark.createDataFrame([(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], schema=schema)
df.show()
```
## How was this patch tested?
PySpark's sql tests.
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Closes #13778 from viirya/fix-pyudt.
2016-08-02 13:08:18 -04:00
|
|
|
def test_simple_udt_in_df(self):
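# Regression test for SPARK-16062: a Python-only UDT as a top-level column must be
# decodable by df.show().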
|
|
|
|
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
|
|
|
|
df = self.spark.createDataFrame(
|
|
|
|
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
|
|
|
|
schema=schema)
|
|
|
|
df.show()
|
|
|
|
|
|
|
|
def test_nested_udt_in_df(self):
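# Python-only UDTs nested inside ArrayType and MapType values must round-trip
# through collect().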
|
|
|
|
schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
|
|
|
|
df = self.spark.createDataFrame(
|
|
|
|
[(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
|
|
|
|
schema=schema)
|
|
|
|
df.collect()
|
|
|
|
|
|
|
|
schema = StructType().add("key", LongType()).add("val",
|
|
|
|
MapType(LongType(), PythonOnlyUDT()))
|
|
|
|
df = self.spark.createDataFrame(
|
|
|
|
[(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
|
|
|
|
schema=schema)
|
|
|
|
df.collect()
|
|
|
|
|
|
|
|
def test_complex_nested_udt_in_df(self):
|
|
|
|
from pyspark.sql.functions import udf
|
|
|
|
|
|
|
|
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
|
|
|
|
df = self.spark.createDataFrame(
|
|
|
|
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
|
|
|
|
schema=schema)
|
|
|
|
df.collect()
|
|
|
|
|
|
|
|
gd = df.groupby("key").agg({"val": "collect_list"})
|
|
|
|
gd.collect()
|
|
|
|
udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
|
|
|
|
gd.select(udf(*gd)).collect()
|
|
|
|
|
2016-06-28 17:09:38 -04:00
|
|
|
def test_udt_with_none(self):
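# A registered UDF returning None for a UDT column should produce a null value
# rather than failing.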
|
|
|
|
df = self.spark.range(0, 10, 1, 1)
|
|
|
|
|
|
|
|
def myudf(x):
|
|
|
|
if x > 0:
|
|
|
|
return PythonOnlyPoint(float(x), float(x))
|
|
|
|
|
|
|
|
self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
|
|
|
|
rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
|
|
|
|
self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
|
|
|
|
|
2018-01-23 21:43:48 -05:00
|
|
|
def test_nonparam_udf_with_aggregate(self):
|
|
|
|
import pyspark.sql.functions as f
|
|
|
|
|
|
|
|
df = self.spark.createDataFrame([(1, 2), (1, 2)])
|
|
|
|
f_udf = f.udf(lambda: "const_str")
|
|
|
|
rows = df.distinct().withColumn("a", f_udf()).collect()
|
|
|
|
self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_infer_schema_with_udt(self):
|
2015-02-09 23:49:22 -05:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-02-03 19:01:56 -05:00
|
|
|
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-02-14 02:03:22 -05:00
|
|
|
schema = df.schema
|
2015-02-03 19:01:56 -05:00
|
|
|
field = [f for f in schema.fields if f.name == "point"][0]
|
|
|
|
self.assertEqual(type(field.dataType), ExamplePointUDT)
|
2016-05-17 21:01:59 -04:00
|
|
|
df.createOrReplaceTempView("labeled_point")
|
2016-05-11 14:24:16 -04:00
|
|
|
point = self.spark.sql("SELECT point FROM labeled_point").head().point
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(point, ExamplePoint(1.0, 2.0))
|
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-07-30 01:30:49 -04:00
|
|
|
schema = df.schema
|
|
|
|
field = [f for f in schema.fields if f.name == "point"][0]
|
|
|
|
self.assertEqual(type(field.dataType), PythonOnlyUDT)
|
2016-05-17 21:01:59 -04:00
|
|
|
df.createOrReplaceTempView("labeled_point")
|
2016-05-11 14:24:16 -04:00
|
|
|
point = self.spark.sql("SELECT point FROM labeled_point").head().point
|
2015-07-30 01:30:49 -04:00
|
|
|
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_apply_schema_with_udt(self):
|
2015-02-09 23:49:22 -05:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-02-03 19:01:56 -05:00
|
|
|
row = (1.0, ExamplePoint(1.0, 2.0))
|
|
|
|
schema = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", ExamplePointUDT(), False)])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row], schema)
|
2015-02-03 19:01:56 -05:00
|
|
|
point = df.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, ExamplePoint(1.0, 2.0))
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = (1.0, PythonOnlyPoint(1.0, 2.0))
|
|
|
|
schema = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", PythonOnlyUDT(), False)])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row], schema)
|
2015-07-30 01:30:49 -04:00
|
|
|
point = df.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
|
2015-07-30 01:30:49 -04:00
|
|
|
|
2015-07-09 17:43:38 -04:00
|
|
|
def test_udf_with_udt(self):
|
2015-07-20 15:14:47 -04:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-07-09 17:43:38 -04:00
|
|
|
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2016-03-02 18:26:34 -05:00
|
|
|
self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
|
2015-07-09 17:43:38 -04:00
|
|
|
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
|
|
|
|
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
|
2015-07-20 15:14:47 -04:00
|
|
|
udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
|
|
|
|
self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
|
2015-07-09 17:43:38 -04:00
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2016-03-02 18:26:34 -05:00
|
|
|
self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
|
2015-07-30 01:30:49 -04:00
|
|
|
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
|
|
|
|
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
|
|
|
|
udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
|
|
|
|
self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_parquet_with_udt(self):
|
2015-07-30 01:30:49 -04:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-02-03 19:01:56 -05:00
|
|
|
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df0 = self.spark.createDataFrame([row])
|
2015-02-03 19:01:56 -05:00
|
|
|
output_dir = os.path.join(self.tempdir.name, "labeled_point")
|
2015-07-30 01:30:49 -04:00
|
|
|
df0.write.parquet(output_dir)
|
2016-05-11 14:24:16 -04:00
|
|
|
df1 = self.spark.read.parquet(output_dir)
|
2015-02-03 19:01:56 -05:00
|
|
|
point = df1.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, ExamplePoint(1.0, 2.0))
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df0 = self.spark.createDataFrame([row])
|
2015-07-30 01:30:49 -04:00
|
|
|
df0.write.parquet(output_dir, mode='overwrite')
|
2016-05-11 14:24:16 -04:00
|
|
|
df1 = self.spark.read.parquet(output_dir)
|
2015-07-30 01:30:49 -04:00
|
|
|
point = df1.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
|
2015-07-30 01:30:49 -04:00
|
|
|
|
2016-03-25 01:34:55 -04:00
|
|
|
def test_union_with_udt(self):
|
2016-02-21 19:58:17 -05:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
|
|
|
row1 = (1.0, ExamplePoint(1.0, 2.0))
|
|
|
|
row2 = (2.0, ExamplePoint(3.0, 4.0))
|
|
|
|
schema = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", ExamplePointUDT(), False)])
|
2016-05-11 14:24:16 -04:00
|
|
|
df1 = self.spark.createDataFrame([row1], schema)
|
|
|
|
df2 = self.spark.createDataFrame([row2], schema)
|
2016-02-21 19:58:17 -05:00
|
|
|
|
2016-03-25 01:34:55 -04:00
|
|
|
result = df1.union(df2).orderBy("label").collect()
|
2016-02-21 19:58:17 -05:00
|
|
|
self.assertEqual(
|
|
|
|
result,
|
|
|
|
[
|
|
|
|
Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
|
|
|
|
Row(label=2.0, point=ExamplePoint(3.0, 4.0))
|
|
|
|
]
|
|
|
|
)
|
|
|
|
|
2018-01-18 22:37:08 -05:00
|
|
|
def test_cast_to_string_with_udt(self):
|
|
|
|
from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
|
|
|
|
from pyspark.sql.functions import col
|
|
|
|
row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
|
|
|
|
schema = StructType([StructField("point", ExamplePointUDT(), False),
|
|
|
|
StructField("pypoint", PythonOnlyUDT(), False)])
|
|
|
|
df = self.spark.createDataFrame([row], schema)
|
|
|
|
|
|
|
|
result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
|
|
|
|
self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_column_operators(self):
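# Arithmetic, comparison, boolean and string operators on Columns should all build
# Column expressions rather than evaluating eagerly.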
|
|
|
|
ci = self.df.key
|
|
|
|
cs = self.df.value
|
|
|
|
c = ci == cs
|
|
|
|
self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
|
2015-09-11 18:19:04 -04:00
|
|
|
rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in rcc))
|
2015-06-23 18:51:16 -04:00
|
|
|
cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in cb))
|
|
|
|
cbool = (ci & ci), (ci | ci), (~ci)
|
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in cbool))
|
2017-02-23 16:22:39 -05:00
|
|
|
css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\
|
2017-05-01 12:43:32 -04:00
|
|
|
cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs)
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in css))
|
|
|
|
self.assertTrue(isinstance(ci.cast(LongType()), Column))
|
[SPARK-19701][SQL][PYTHON] Throws a correct exception for 'in' operator against column
## What changes were proposed in this pull request?
This PR proposes to remove an incorrect implementation of the `in` operator that has never actually been executed (at least since Spark 1.5.2), and to throw a correct exception rather than saying it is a bool. I tested the code in 1.5.2, 1.6.3, 2.1.0 and in the master branch as below:
**1.5.2**
```python
>>> df = sqlContext.createDataFrame([[1]])
>>> 1 in df._1
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark-1.5.2-bin-hadoop2.6/python/pyspark/sql/column.py", line 418, in __nonzero__
raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
```
**1.6.3**
```python
>>> 1 in sqlContext.range(1).id
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark-1.6.3-bin-hadoop2.6/python/pyspark/sql/column.py", line 447, in __nonzero__
raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
```
**2.1.0**
```python
>>> 1 in spark.range(1).id
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark-2.1.0-bin-hadoop2.7/python/pyspark/sql/column.py", line 426, in __nonzero__
raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
```
**Current Master**
```python
>>> 1 in spark.range(1).id
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark/python/pyspark/sql/column.py", line 452, in __nonzero__
raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
```
**After**
```python
>>> 1 in spark.range(1).id
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark/python/pyspark/sql/column.py", line 184, in __contains__
raise ValueError("Cannot apply 'in' operator against a column: please use 'contains' "
ValueError: Cannot apply 'in' operator against a column: please use 'contains' in a string column or 'array_contains' function for an array column.
```
In more detail, it seems the implementation intended to support this:
```python
1 in df.column
```
However, currently, it throws an exception as below:
```python
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark/python/pyspark/sql/column.py", line 426, in __nonzero__
raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
```
What happens here is as below:
```python
class Column(object):
def __contains__(self, item):
print "I am contains"
return Column()
def __nonzero__(self):
raise Exception("I am nonzero.")
>>> 1 in Column()
I am contains
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 6, in __nonzero__
Exception: I am nonzero.
```
It seems it calls `__contains__` first, and then `__nonzero__` or `__bool__` is called against the returned `Column()` to turn the result into a bool (or an int, to be specific).
It seems `__nonzero__` (for Python 2), `__bool__` (for Python 3) and `__contains__` force the return value into a bool, unlike other operators. There are a few references about this below:
https://bugs.python.org/issue16011
http://stackoverflow.com/questions/12244074/python-source-code-for-built-in-in-operator/12244378#12244378
http://stackoverflow.com/questions/38542543/functionality-of-python-in-vs-contains/38542777
It seems we can't override `__nonzero__` or `__bool__` as a workaround to make this work, because these force the return type to be a bool, as below:
```python
class Column(object):
def __contains__(self, item):
print "I am contains"
return Column()
def __nonzero__(self):
return "a"
>>> 1 in Column()
I am contains
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: __nonzero__ should return bool or int, returned str
```
## How was this patch tested?
Added unit tests in `tests.py`.
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #17160 from HyukjinKwon/SPARK-19701.
2017-03-05 21:04:52 -05:00
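# Python routes 'x in col' through __bool__/__nonzero__, so Column.__contains__
# raises a descriptive ValueError instead (SPARK-19701).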
|
|
|
self.assertRaisesRegexp(ValueError,
|
|
|
|
"Cannot apply 'in' operator against a column",
|
|
|
|
lambda: 1 in cs)
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2017-02-13 18:23:56 -05:00
|
|
|
def test_column_getitem(self):
|
|
|
|
from pyspark.sql.functions import col
|
|
|
|
|
|
|
|
self.assertIsInstance(col("foo")[1:3], Column)
|
|
|
|
self.assertIsInstance(col("foo")[0], Column)
|
|
|
|
self.assertIsInstance(col("foo")["bar"], Column)
|
|
|
|
self.assertRaises(ValueError, lambda: col("foo")[0:10:2])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_column_select(self):
|
|
|
|
df = self.df
|
|
|
|
self.assertEqual(self.testData, df.select("*").collect())
|
|
|
|
self.assertEqual(self.testData, df.select(df.key, df.value).collect())
|
|
|
|
self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
|
|
|
|
|
2015-05-02 02:43:24 -04:00
|
|
|
def test_freqItems(self):
|
|
|
|
vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
|
|
|
|
df = self.sc.parallelize(vals).toDF()
|
|
|
|
items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
|
|
|
|
self.assertTrue(1 in items[0])
|
|
|
|
self.assertTrue(-2.0 in items[1])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_aggregator(self):
|
|
|
|
df = self.df
|
|
|
|
g = df.groupBy()
|
|
|
|
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
|
|
|
|
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
|
|
|
|
|
2015-02-14 02:03:22 -05:00
|
|
|
from pyspark.sql import functions
|
|
|
|
self.assertEqual((0, u'99'),
|
|
|
|
tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
|
|
|
|
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
|
|
|
|
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2016-01-31 16:56:13 -05:00
|
|
|
def test_first_last_ignorenulls(self):
|
|
|
|
from pyspark.sql import functions
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.range(0, 100)
|
2016-01-31 16:56:13 -05:00
|
|
|
df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
|
|
|
|
df3 = df2.select(functions.first(df2.id, False).alias('a'),
|
|
|
|
functions.first(df2.id, True).alias('b'),
|
|
|
|
functions.last(df2.id, False).alias('c'),
|
|
|
|
functions.last(df2.id, True).alias('d'))
|
|
|
|
self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
|
|
|
|
|
2016-02-25 02:15:36 -05:00
|
|
|
def test_approxQuantile(self):
|
2017-02-01 17:11:28 -05:00
|
|
|
df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
|
2017-09-08 14:57:33 -04:00
|
|
|
for f in ["a", u"a"]:
|
|
|
|
aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
|
|
|
|
self.assertTrue(isinstance(aq, list))
|
|
|
|
self.assertEqual(len(aq), 3)
|
2016-02-25 02:15:36 -05:00
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aq))
|
2017-09-08 14:57:33 -04:00
|
|
|
aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
|
2017-02-01 17:11:28 -05:00
|
|
|
self.assertTrue(isinstance(aqs, list))
|
|
|
|
self.assertEqual(len(aqs), 2)
|
|
|
|
self.assertTrue(isinstance(aqs[0], list))
|
|
|
|
self.assertEqual(len(aqs[0]), 3)
|
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
|
|
|
|
self.assertTrue(isinstance(aqs[1], list))
|
|
|
|
self.assertEqual(len(aqs[1]), 3)
|
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
|
2017-09-08 14:57:33 -04:00
|
|
|
aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
|
2017-02-01 17:11:28 -05:00
|
|
|
self.assertTrue(isinstance(aqt, list))
|
|
|
|
self.assertEqual(len(aqt), 2)
|
|
|
|
self.assertTrue(isinstance(aqt[0], list))
|
|
|
|
self.assertEqual(len(aqt[0]), 3)
|
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
|
|
|
|
self.assertTrue(isinstance(aqt[1], list))
|
|
|
|
self.assertEqual(len(aqt[1]), 3)
|
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
|
|
|
|
self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
|
|
|
|
self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
|
|
|
|
self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
|
2016-02-25 02:15:36 -05:00
|
|
|
|
2015-05-04 00:44:39 -04:00
|
|
|
def test_corr(self):
|
|
|
|
import math
|
|
|
|
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
|
2017-09-08 14:57:33 -04:00
|
|
|
corr = df.stat.corr(u"a", "b")
|
2015-05-04 00:44:39 -04:00
|
|
|
self.assertTrue(abs(corr - 0.95734012) < 1e-6)
|
|
|
|
|
2017-09-08 14:57:33 -04:00
|
|
|
def test_sampleby(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
|
|
|
|
sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
|
|
|
|
self.assertTrue(sampled.count() == 3)
|
|
|
|
|
2015-05-01 16:29:17 -04:00
|
|
|
def test_cov(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
|
2017-09-08 14:57:33 -04:00
|
|
|
cov = df.stat.cov(u"a", "b")
|
2015-05-01 16:29:17 -04:00
|
|
|
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
|
|
|
|
|
2015-05-04 20:02:49 -04:00
|
|
|
def test_crosstab(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
|
2017-09-08 14:57:33 -04:00
|
|
|
ct = df.stat.crosstab(u"a", "b").collect()
|
2015-05-04 20:02:49 -04:00
|
|
|
ct = sorted(ct, key=lambda x: x[0])
|
|
|
|
for i, row in enumerate(ct):
|
|
|
|
self.assertEqual(row[0], str(i))
|
|
|
|
self.assertEqual(row[1], 1)
|
|
|
|
self.assertEqual(row[2], 1)
|
|
|
|
|
2015-04-29 03:09:24 -04:00
|
|
|
def test_math_functions(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
|
2015-05-06 01:56:01 -04:00
|
|
|
from pyspark.sql import functions
|
2015-04-29 03:09:24 -04:00
|
|
|
import math
|
|
|
|
|
|
|
|
def get_values(l):
|
|
|
|
return [j[0] for j in l]
|
|
|
|
|
|
|
|
def assert_close(a, b):
|
|
|
|
c = get_values(b)
|
|
|
|
diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
|
|
|
|
return sum(diff) == len(a)
|
|
|
|
assert_close([math.cos(i) for i in range(10)],
|
|
|
|
df.select(functions.cos(df.a)).collect())
|
|
|
|
assert_close([math.cos(i) for i in range(10)],
|
|
|
|
df.select(functions.cos("a")).collect())
|
|
|
|
assert_close([math.sin(i) for i in range(10)],
|
|
|
|
df.select(functions.sin(df.a)).collect())
|
|
|
|
assert_close([math.sin(i) for i in range(10)],
|
|
|
|
df.select(functions.sin(df['a'])).collect())
|
|
|
|
assert_close([math.pow(i, 2 * i) for i in range(10)],
|
|
|
|
df.select(functions.pow(df.a, df.b)).collect())
|
|
|
|
assert_close([math.pow(i, 2) for i in range(10)],
|
|
|
|
df.select(functions.pow(df.a, 2)).collect())
|
|
|
|
assert_close([math.pow(i, 2) for i in range(10)],
|
|
|
|
df.select(functions.pow(df.a, 2.0)).collect())
|
|
|
|
assert_close([math.hypot(i, 2 * i) for i in range(10)],
|
|
|
|
df.select(functions.hypot(df.a, df.b)).collect())
|
|
|
|
|
2015-05-01 00:56:03 -04:00
|
|
|
def test_rand_functions(self):
|
|
|
|
df = self.df
|
|
|
|
from pyspark.sql import functions
|
|
|
|
rnd = df.select('key', functions.rand()).collect()
|
|
|
|
for row in rnd:
|
|
|
|
assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
|
|
|
|
rndn = df.select('key', functions.randn(5)).collect()
|
|
|
|
for row in rndn:
|
|
|
|
assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
|
|
|
|
|
2015-08-06 20:03:14 -04:00
|
|
|
# If the specified seed is 0, we should use it.
|
|
|
|
# https://issues.apache.org/jira/browse/SPARK-9691
|
|
|
|
rnd1 = df.select('key', functions.rand(0)).collect()
|
|
|
|
rnd2 = df.select('key', functions.rand(0)).collect()
|
|
|
|
self.assertEqual(sorted(rnd1), sorted(rnd2))
|
|
|
|
|
|
|
|
rndn1 = df.select('key', functions.randn(0)).collect()
|
|
|
|
rndn2 = df.select('key', functions.randn(0)).collect()
|
|
|
|
self.assertEqual(sorted(rndn1), sorted(rndn2))
|
|
|
|
|
2017-08-15 22:19:15 -04:00
|
|
|
def test_string_functions(self):
|
|
|
|
from pyspark.sql.functions import col, lit
|
|
|
|
df = self.spark.createDataFrame([['nick']], schema=['name'])
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
TypeError,
|
|
|
|
"must be the same type",
|
|
|
|
lambda: df.select(col('name').substr(0, lit(1))))
|
|
|
|
if sys.version_info.major == 2:
|
|
|
|
self.assertRaises(
|
|
|
|
TypeError,
|
|
|
|
lambda: df.select(col('name').substr(long(0), long(1))))
|
|
|
|
|
2017-03-26 21:40:00 -04:00
|
|
|
def test_array_contains_function(self):
|
|
|
|
from pyspark.sql.functions import array_contains
|
|
|
|
|
|
|
|
df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data'])
|
|
|
|
actual = df.select(array_contains(df.data, 1).alias('b')).collect()
|
|
|
|
# The value argument can be implicitly castable to the element's type of the array.
|
|
|
|
self.assertEqual([Row(b=True), Row(b=False)], actual)
|
|
|
|
|
2015-05-05 16:23:53 -04:00
|
|
|
def test_between_function(self):
|
|
|
|
df = self.sc.parallelize([
|
|
|
|
Row(a=1, b=2, c=3),
|
|
|
|
Row(a=2, b=1, c=3),
|
|
|
|
Row(a=4, b=1, c=4)]).toDF()
|
|
|
|
self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
|
|
|
|
df.filter(df.a.between(df.b, df.c)).collect())
|
|
|
|
|
2015-06-29 17:15:15 -04:00
|
|
|
def test_struct_type(self):
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True),
|
|
|
|
StructField("f2", StringType(), True, None)])
|
2017-07-28 23:59:32 -04:00
|
|
|
self.assertEqual(struct1.fieldNames(), struct2.names)
|
2015-06-29 17:15:15 -04:00
|
|
|
self.assertEqual(struct1, struct2)
|
|
|
|
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True)])
|
2017-07-28 23:59:32 -04:00
|
|
|
self.assertNotEqual(struct1.fieldNames(), struct2.names)
|
2015-06-29 17:15:15 -04:00
|
|
|
self.assertNotEqual(struct1, struct2)
|
|
|
|
|
|
|
|
struct1 = (StructType().add(StructField("f1", StringType(), True))
|
|
|
|
.add(StructField("f2", StringType(), True, None)))
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True),
|
|
|
|
StructField("f2", StringType(), True, None)])
|
2017-07-28 23:59:32 -04:00
|
|
|
self.assertEqual(struct1.fieldNames(), struct2.names)
|
2015-06-29 17:15:15 -04:00
|
|
|
self.assertEqual(struct1, struct2)
|
|
|
|
|
|
|
|
struct1 = (StructType().add(StructField("f1", StringType(), True))
|
|
|
|
.add(StructField("f2", StringType(), True, None)))
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True)])
|
2017-07-28 23:59:32 -04:00
|
|
|
self.assertNotEqual(struct1.fieldNames(), struct2.names)
|
2015-06-29 17:15:15 -04:00
|
|
|
self.assertNotEqual(struct1, struct2)
|
|
|
|
|
|
|
|
# Catch exception raised during improper construction
|
2017-07-28 23:59:32 -04:00
|
|
|
self.assertRaises(ValueError, lambda: StructType().add("name"))
|
2016-04-20 16:45:14 -04:00
|
|
|
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
for field in struct1:
|
|
|
|
self.assertIsInstance(field, StructField)
|
|
|
|
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
self.assertEqual(len(struct1), 2)
|
|
|
|
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
self.assertIs(struct1["f1"], struct1.fields[0])
|
|
|
|
self.assertIs(struct1[0], struct1.fields[0])
|
|
|
|
self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
|
2017-07-28 23:59:32 -04:00
|
|
|
self.assertRaises(KeyError, lambda: struct1["f9"])
|
|
|
|
self.assertRaises(IndexError, lambda: struct1[9])
|
|
|
|
self.assertRaises(TypeError, lambda: struct1[9.9])
|
2015-06-29 17:15:15 -04:00
|
|
|
|
[SPARK-21365][PYTHON] Deduplicate logics parsing DDL type/schema definition
## What changes were proposed in this pull request?
This PR deals with four points as below:
- Reuse existing DDL parser APIs rather than reimplementing within PySpark
- Support DDL-formatted strings, e.g. `field type, field type`.
- Support case-insensitivity for parsing.
- Support nested data types as below:
**Before**
```
>>> spark.createDataFrame([[[1]]], "struct<a: struct<b: int>>").show()
...
ValueError: The strcut field string format is: 'field_name:field_type', but got: a: struct<b: int>
```
```
>>> spark.createDataFrame([[[1]]], "a: struct<b: int>").show()
...
ValueError: The strcut field string format is: 'field_name:field_type', but got: a: struct<b: int>
```
```
>>> spark.createDataFrame([[1]], "a int").show()
...
ValueError: Could not parse datatype: a int
```
**After**
```
>>> spark.createDataFrame([[[1]]], "struct<a: struct<b: int>>").show()
+---+
| a|
+---+
|[1]|
+---+
```
```
>>> spark.createDataFrame([[[1]]], "a: struct<b: int>").show()
+---+
| a|
+---+
|[1]|
+---+
```
```
>>> spark.createDataFrame([[1]], "a int").show()
+---+
| a|
+---+
| 1|
+---+
```
## How was this patch tested?
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #18590 from HyukjinKwon/deduplicate-python-ddl.
2017-07-11 10:03:10 -04:00
|
|
|
def test_parse_datatype_string(self):
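# DDL-formatted type strings ("a INT, c DOUBLE", "struct<a:int, c:double>") should parse
# case-insensitively and tolerate extra whitespace.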
|
|
|
|
from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
|
|
|
|
for k, t in _all_atomic_types.items():
|
|
|
|
if t != NullType:
|
|
|
|
self.assertEqual(t(), _parse_datatype_string(k))
|
|
|
|
self.assertEqual(IntegerType(), _parse_datatype_string("int"))
|
|
|
|
self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
|
|
|
|
self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
|
|
|
|
self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
|
|
|
|
self.assertEqual(
|
|
|
|
ArrayType(IntegerType()),
|
|
|
|
_parse_datatype_string("array<int >"))
|
|
|
|
self.assertEqual(
|
|
|
|
MapType(IntegerType(), DoubleType()),
|
|
|
|
_parse_datatype_string("map< int, double >"))
|
|
|
|
self.assertEqual(
|
|
|
|
StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
|
|
|
|
_parse_datatype_string("struct<a:int, c:double >"))
|
|
|
|
self.assertEqual(
|
|
|
|
StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
|
|
|
|
_parse_datatype_string("a:int, c:double"))
|
|
|
|
self.assertEqual(
|
|
|
|
StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
|
|
|
|
_parse_datatype_string("a INT, c DOUBLE"))
|
|
|
|
|
2016-01-27 12:55:10 -05:00
|
|
|
def test_metadata_null(self):
|
|
|
|
schema = StructType([StructField("f1", StringType(), True, None),
|
|
|
|
StructField("f2", StringType(), True, {'a': None})])
|
|
|
|
rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.createDataFrame(rdd, schema)
|
2016-01-27 12:55:10 -05:00
|
|
|
|
2015-02-10 20:29:52 -05:00
|
|
|
def test_save_and_load(self):
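# Round-trip a DataFrame through JSON save/load with and without an explicit schema,
# exercise overwrite mode, generic save/load options, and the spark.sql.sources.default
# fallback.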
|
|
|
|
df = self.df
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.json(tmpPath)
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
|
|
|
schema = StructType([StructField("value", StringType(), True)])
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath, schema)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.json(tmpPath, "overwrite")
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.save(format="json", mode="overwrite", path=tmpPath,
|
|
|
|
noUse="this options will not be used in save.")
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.load(format="json", path=tmpPath,
|
|
|
|
noUse="this options will not be used in load.")
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
2016-05-11 14:24:16 -04:00
|
|
|
defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
|
2015-02-10 20:29:52 -05:00
|
|
|
"org.apache.spark.sql.parquet")
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
|
|
|
|
actual = self.spark.read.load(path=tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
|
2015-06-22 16:51:23 -04:00
|
|
|
|
2016-04-22 12:19:36 -04:00
|
|
|
csvpath = os.path.join(tempfile.mkdtemp(), 'data')
|
|
|
|
df.write.option('quote', None).format('csv').save(csvpath)
|
|
|
|
|
2015-06-22 16:51:23 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
|
|
|
def test_save_and_load_builder(self):
|
|
|
|
df = self.df
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
df.write.json(tmpPath)
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath)
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
|
|
|
|
|
|
|
schema = StructType([StructField("value", StringType(), True)])
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath, schema)
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
|
|
|
|
|
|
|
|
df.write.mode("overwrite").json(tmpPath)
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath)
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
|
|
|
|
|
|
|
df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
|
2015-06-29 03:13:39 -04:00
|
|
|
.option("noUse", "this option will not be used in save.")\
|
2015-06-22 16:51:23 -04:00
|
|
|
.format("json").save(path=tmpPath)
|
|
|
|
actual =\
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.read.format("json")\
|
|
|
|
.load(path=tmpPath, noUse="this options will not be used in load.")
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
|
|
|
|
2016-05-11 14:24:16 -04:00
|
|
|
defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
|
2015-06-22 16:51:23 -04:00
|
|
|
"org.apache.spark.sql.parquet")
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
|
|
|
|
actual = self.spark.read.load(path=tmpPath)
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
|
2015-02-10 20:29:52 -05:00
|
|
|
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
[SPARK-19876][SS][WIP] OneTime Trigger Executor
## What changes were proposed in this pull request?
An additional trigger and trigger executor that will execute a single trigger only. One can use this OneTime trigger to have more control over the scheduling of triggers.
In addition, this patch requires an optimization to StreamExecution that logs a commit record at the end of successfully processing a batch. This new commit log will be used to determine the next batch (offsets) to process after a restart, instead of using the offset log itself; using the offset log for this would always re-process the previously logged batch, which would not permit a OneTime trigger feature.
## How was this patch tested?
A number of existing tests have been revised. These tests all assumed that when restarting a stream, the last batch in the offset log is to be re-processed. Given that we now have a commit log that will tell us if that last batch was processed successfully, the results/assumptions of those tests needed to be revised accordingly.
In addition, a OneTime trigger test was added to StreamingQuerySuite, which tests:
- The semantics of OneTime trigger (i.e., on start, execute a single batch, then stop).
- The case when the commit log was not able to successfully log the completion of a batch before restart, which would mean that we should fall back to what's in the offset log.
- A OneTime trigger execution that results in an exception being thrown.
marmbrus tdas zsxwing
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Tyson Condie <tcondie@gmail.com>
Author: Tathagata Das <tathagata.das1565@gmail.com>
Closes #17219 from tcondie/stream-commit.
2017-03-23 17:32:05 -04:00
|
|
|
def test_stream_trigger(self):
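# trigger() should reject calls with no argument, with conflicting arguments,
# and with a positional argument.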
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
2017-03-23 17:32:05 -04:00
|
|
|
|
|
|
|
# Should take at least one arg
|
|
|
|
try:
|
|
|
|
df.writeStream.trigger()
self.fail("Should have thrown an exception")
|
|
|
|
except ValueError:
|
|
|
|
pass
|
|
|
|
|
|
|
|
# Should not take multiple args
|
|
|
|
try:
|
|
|
|
df.writeStream.trigger(once=True, processingTime='5 seconds')
self.fail("Should have thrown an exception")
|
|
|
|
except ValueError:
|
|
|
|
pass
|
|
|
|
|
2018-01-18 15:25:52 -05:00
|
|
|
# Should not take multiple args
|
|
|
|
try:
|
|
|
|
df.writeStream.trigger(processingTime='5 seconds', continuous='1 second')
self.fail("Should have thrown an exception")
|
|
|
|
except ValueError:
|
|
|
|
pass
|
|
|
|
|
2017-03-23 17:32:05 -04:00
|
|
|
# Should take only keyword args
|
2016-04-20 13:32:01 -04:00
|
|
|
try:
|
2016-06-14 20:58:45 -04:00
|
|
|
df.writeStream.trigger('5 seconds')
|
2016-04-20 13:32:01 -04:00
|
|
|
self.fail("Should have thrown an exception")
|
|
|
|
except TypeError:
|
|
|
|
pass
|
|
|
|
|
|
|
|
def test_stream_read_options(self):
|
|
|
|
schema = StructType([StructField("data", StringType(), False)])
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream\
|
|
|
|
.format('text')\
|
|
|
|
.option('path', 'python/test_support/sql/streaming')\
|
|
|
|
.schema(schema)\
|
|
|
|
.load()
|
2016-04-20 13:32:01 -04:00
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<data:string>")
|
|
|
|
|
|
|
|
def test_stream_read_options_overwrite(self):
|
|
|
|
bad_schema = StructType([StructField("test", IntegerType(), False)])
|
|
|
|
schema = StructType([StructField("data", StringType(), False)])
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \
|
|
|
|
.schema(bad_schema)\
|
|
|
|
.load(path='python/test_support/sql/streaming', schema=schema, format='text')
|
2016-04-20 13:32:01 -04:00
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<data:string>")
|
|
|
|
|
|
|
|
    def test_stream_save_options(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \
            .withColumn('id', lit(1))
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')
        q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \
            .format('parquet').partitionBy('id').outputMode('append').option('path', out).start()
        try:
            self.assertEqual(q.name, 'this_query')
            self.assertTrue(q.isActive)
            q.processAllAvailable()
            output_files = []
            for _, _, files in os.walk(out):
                output_files.extend([f for f in files if not f.startswith('.')])
            self.assertTrue(len(output_files) > 0)
            self.assertTrue(len(os.listdir(chk)) > 0)
        finally:
            q.stop()
            shutil.rmtree(tmpPath)

def test_stream_save_options_overwrite(self):
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
2016-06-15 13:46:02 -04:00
|
|
|
for q in self.spark._wrapped.streams.active:
|
|
|
|
q.stop()
|
2016-04-20 13:32:01 -04:00
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
out = os.path.join(tmpPath, 'out')
|
|
|
|
chk = os.path.join(tmpPath, 'chk')
|
|
|
|
fake1 = os.path.join(tmpPath, 'fake1')
|
|
|
|
fake2 = os.path.join(tmpPath, 'fake2')
|
2016-06-15 13:46:02 -04:00
|
|
|
q = df.writeStream.option('checkpointLocation', fake1)\
|
2016-06-14 20:58:45 -04:00
|
|
|
.format('memory').option('path', fake2) \
|
2016-05-31 18:57:01 -04:00
|
|
|
.queryName('fake_query').outputMode('append') \
|
2016-06-14 20:58:45 -04:00
|
|
|
.start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
|
2016-05-31 18:57:01 -04:00
|
|
|
|
2016-04-28 18:22:28 -04:00
|
|
|
try:
|
2016-06-15 13:46:02 -04:00
|
|
|
self.assertEqual(q.name, 'this_query')
|
|
|
|
self.assertTrue(q.isActive)
|
|
|
|
q.processAllAvailable()
|
2016-04-28 18:22:28 -04:00
|
|
|
output_files = []
|
|
|
|
for _, _, files in os.walk(out):
|
2016-05-03 13:58:26 -04:00
|
|
|
output_files.extend([f for f in files if not f.startswith('.')])
|
2016-04-28 18:22:28 -04:00
|
|
|
self.assertTrue(len(output_files) > 0)
|
|
|
|
self.assertTrue(len(os.listdir(chk)) > 0)
|
|
|
|
self.assertFalse(os.path.isdir(fake1)) # should not have been created
|
|
|
|
self.assertFalse(os.path.isdir(fake2)) # should not have been created
|
|
|
|
finally:
|
2016-06-15 13:46:02 -04:00
|
|
|
q.stop()
|
2016-04-28 18:22:28 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
2016-04-20 13:32:01 -04:00
|
|
|
|
[SPARK-18516][SQL] Split state and progress in streaming
This PR separates the status of a `StreamingQuery` into two separate APIs:
- `status` - describes the status of a `StreamingQuery` at this moment, including what phase of processing is currently happening and if data is available.
- `recentProgress` - an array of statistics about the most recent microbatches that have executed.
A recent progress contains the following information:
```
{
"id" : "2be8670a-fce1-4859-a530-748f29553bb6",
"name" : "query-29",
"timestamp" : 1479705392724,
"inputRowsPerSecond" : 230.76923076923077,
"processedRowsPerSecond" : 10.869565217391303,
"durationMs" : {
"triggerExecution" : 276,
"queryPlanning" : 3,
"getBatch" : 5,
"getOffset" : 3,
"addBatch" : 234,
"walCommit" : 30
},
"currentWatermark" : 0,
"stateOperators" : [ ],
"sources" : [ {
"description" : "KafkaSource[Subscribe[topic-14]]",
"startOffset" : {
"topic-14" : {
"2" : 0,
"4" : 1,
"1" : 0,
"3" : 0,
"0" : 0
}
},
"endOffset" : {
"topic-14" : {
"2" : 1,
"4" : 2,
"1" : 0,
"3" : 0,
"0" : 1
}
},
"numRecords" : 3,
"inputRowsPerSecond" : 230.76923076923077,
"processedRowsPerSecond" : 10.869565217391303
} ]
}
```
Additionally, in order to make it possible to correlate progress updates across restarts, we change the `id` field from an integer that is unique with in the JVM to a `UUID` that is globally unique.
Author: Tathagata Das <tathagata.das1565@gmail.com>
Author: Michael Armbrust <michael@databricks.com>
Closes #15954 from marmbrus/queryProgress.
2016-11-29 20:24:17 -05:00
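A short sketch of how these two APIs (plus `lastProgress`) are read from PySpark; `q` is assumed to be an active `StreamingQuery`, and the dictionary keys follow the commit message above:
```python
# `q` is an active StreamingQuery, e.g. the return value of writeStream.start().
print(q.status)             # e.g. {'message': ..., 'isDataAvailable': ..., 'isTriggerActive': ...}
q.processAllAvailable()     # block until all currently available input is processed
last = q.lastProgress       # dict for the most recent micro-batch, or None if none has completed
recent = q.recentProgress   # list of the most recent progress dicts
if last is not None:
    print(last['id'], last['name'], last['inputRowsPerSecond'])
```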
|
|
|
def test_stream_status_and_progress(self):
|
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
|
|
|
for q in self.spark._wrapped.streams.active:
|
|
|
|
q.stop()
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
out = os.path.join(tmpPath, 'out')
|
|
|
|
chk = os.path.join(tmpPath, 'chk')
|
2016-12-14 16:36:41 -05:00
|
|
|
|
|
|
|
def func(x):
|
|
|
|
time.sleep(1)
|
|
|
|
return x
|
|
|
|
|
|
|
|
from pyspark.sql.functions import col, udf
|
|
|
|
sleep_udf = udf(func)
|
|
|
|
|
|
|
|
# Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there
|
|
|
|
# were no updates.
|
|
|
|
q = df.select(sleep_udf(col("value")).alias('value')).writeStream \
|
|
|
|
.start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
|
|
|
|
try:
|
2016-12-14 16:36:41 -05:00
|
|
|
# "lastProgress" will return None in most cases. However, as it may be flaky when
|
|
|
|
# Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress"
|
|
|
|
# may throw error with a high chance and make this test flaky, so we should still be
|
|
|
|
# able to detect broken codes.
|
|
|
|
q.lastProgress
|
|
|
|
|
|
|
|
q.processAllAvailable()
|
|
|
|
lastProgress = q.lastProgress
|
2016-12-07 18:36:29 -05:00
|
|
|
recentProgress = q.recentProgress
|
2016-11-30 02:08:56 -05:00
|
|
|
status = q.status
|
|
|
|
self.assertEqual(lastProgress['name'], q.name)
|
|
|
|
self.assertEqual(lastProgress['id'], q.id)
|
2016-12-07 18:36:29 -05:00
|
|
|
self.assertTrue(any(p == lastProgress for p in recentProgress))
|
2016-11-30 02:08:56 -05:00
|
|
|
self.assertTrue(
|
|
|
|
"message" in status and
|
|
|
|
"isDataAvailable" in status and
|
|
|
|
"isTriggerActive" in status)
|
|
|
|
finally:
|
|
|
|
q.stop()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
2016-04-20 13:32:01 -04:00
|
|
|
def test_stream_await_termination(self):
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
2016-06-15 13:46:02 -04:00
|
|
|
for q in self.spark._wrapped.streams.active:
|
|
|
|
q.stop()
|
2016-04-20 13:32:01 -04:00
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
out = os.path.join(tmpPath, 'out')
|
|
|
|
chk = os.path.join(tmpPath, 'chk')
|
2016-06-15 13:46:02 -04:00
|
|
|
q = df.writeStream\
|
2016-06-14 20:58:45 -04:00
|
|
|
.start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
|
2016-04-20 13:32:01 -04:00
|
|
|
try:
|
2016-06-15 13:46:02 -04:00
|
|
|
self.assertTrue(q.isActive)
|
2016-04-28 18:22:28 -04:00
|
|
|
try:
|
2016-06-15 13:46:02 -04:00
|
|
|
q.awaitTermination("hello")
|
2016-04-28 18:22:28 -04:00
|
|
|
self.fail("Expected a value exception")
|
|
|
|
except ValueError:
|
|
|
|
pass
|
|
|
|
now = time.time()
|
|
|
|
# test should take at least 2 seconds
|
2016-06-15 13:46:02 -04:00
|
|
|
res = q.awaitTermination(2.6)
|
2016-04-28 18:22:28 -04:00
|
|
|
duration = time.time() - now
|
|
|
|
self.assertTrue(duration >= 2)
|
|
|
|
self.assertFalse(res)
|
|
|
|
finally:
|
2016-06-15 13:46:02 -04:00
|
|
|
q.stop()
|
2016-04-28 18:22:28 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
2016-12-05 14:36:11 -05:00
|
|
|
def test_stream_exception(self):
|
|
|
|
sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
|
|
|
sq = sdf.writeStream.format('memory').queryName('query_explain').start()
|
|
|
|
try:
|
|
|
|
sq.processAllAvailable()
|
|
|
|
self.assertEqual(sq.exception(), None)
|
|
|
|
finally:
|
|
|
|
sq.stop()
|
|
|
|
|
|
|
|
from pyspark.sql.functions import col, udf
|
|
|
|
from pyspark.sql.utils import StreamingQueryException
|
|
|
|
bad_udf = udf(lambda x: 1 / 0)
|
|
|
|
sq = sdf.select(bad_udf(col("value")))\
|
|
|
|
.writeStream\
|
|
|
|
.format('memory')\
|
|
|
|
.queryName('this_query')\
|
|
|
|
.start()
|
|
|
|
try:
|
|
|
|
# Process some data to fail the query
|
|
|
|
sq.processAllAvailable()
|
|
|
|
self.fail("bad udf should fail the query")
|
|
|
|
except StreamingQueryException as e:
|
|
|
|
# This is expected
|
|
|
|
self.assertTrue("ZeroDivisionError" in e.desc)
|
|
|
|
finally:
|
|
|
|
sq.stop()
|
|
|
|
self.assertTrue(type(sq.exception()) is StreamingQueryException)
|
|
|
|
self.assertTrue("ZeroDivisionError" in sq.exception().desc)
|
|
|
|
|
2016-04-28 18:22:28 -04:00
|
|
|
def test_query_manager_await_termination(self):
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
2016-06-15 13:46:02 -04:00
|
|
|
for q in self.spark._wrapped.streams.active:
|
|
|
|
q.stop()
|
2016-04-28 18:22:28 -04:00
|
|
|
tmpPath = tempfile.mkdtemp()
|
2016-04-20 13:32:01 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
2016-04-28 18:22:28 -04:00
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
out = os.path.join(tmpPath, 'out')
|
|
|
|
chk = os.path.join(tmpPath, 'chk')
|
2016-06-15 13:46:02 -04:00
|
|
|
q = df.writeStream\
|
2016-06-14 20:58:45 -04:00
|
|
|
.start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
|
2016-04-28 18:22:28 -04:00
|
|
|
try:
|
2016-06-15 13:46:02 -04:00
|
|
|
self.assertTrue(q.isActive)
|
2016-04-28 18:22:28 -04:00
|
|
|
try:
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark._wrapped.streams.awaitAnyTermination("hello")
|
2016-04-28 18:22:28 -04:00
|
|
|
self.fail("Expected a value exception")
|
|
|
|
except ValueError:
|
|
|
|
pass
|
|
|
|
now = time.time()
|
|
|
|
# test should take at least 2 seconds
|
2016-05-11 14:24:16 -04:00
|
|
|
res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
|
2016-04-28 18:22:28 -04:00
|
|
|
duration = time.time() - now
|
|
|
|
self.assertTrue(duration >= 2)
|
|
|
|
self.assertFalse(res)
|
|
|
|
finally:
|
2016-06-15 13:46:02 -04:00
|
|
|
q.stop()
|
2016-04-28 18:22:28 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
2016-04-20 13:32:01 -04:00
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_help_command(self):
|
|
|
|
# Regression test for SPARK-5464
|
|
|
|
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.read.json(rdd)
|
2015-02-03 19:01:56 -05:00
|
|
|
# render_doc() reproduces the help() exception without printing output
|
|
|
|
pydoc.render_doc(df)
|
|
|
|
pydoc.render_doc(df.foo)
|
|
|
|
pydoc.render_doc(df.take(1))
|
|
|
|
|
2015-04-16 20:33:57 -04:00
|
|
|
def test_access_column(self):
|
|
|
|
df = self.df
|
|
|
|
self.assertTrue(isinstance(df.key, Column))
|
|
|
|
self.assertTrue(isinstance(df['key'], Column))
|
|
|
|
self.assertTrue(isinstance(df[0], Column))
|
|
|
|
self.assertRaises(IndexError, lambda: df[2])
|
2015-08-14 17:09:46 -04:00
|
|
|
self.assertRaises(AnalysisException, lambda: df["bad_key"])
|
2015-04-16 20:33:57 -04:00
|
|
|
self.assertRaises(TypeError, lambda: df[{}])
|
|
|
|
|
2015-07-01 19:43:18 -04:00
|
|
|
def test_column_name_with_non_ascii(self):
|
2016-05-18 14:18:33 -04:00
|
|
|
if sys.version >= '3':
|
|
|
|
columnName = "数量"
|
|
|
|
self.assertTrue(isinstance(columnName, str))
|
|
|
|
else:
|
|
|
|
columnName = unicode("数量", "utf-8")
|
|
|
|
self.assertTrue(isinstance(columnName, unicode))
|
|
|
|
schema = StructType([StructField(columnName, LongType(), True)])
|
|
|
|
df = self.spark.createDataFrame([(1,)], schema)
|
|
|
|
self.assertEqual(schema, df.schema)
|
2015-07-01 19:43:18 -04:00
|
|
|
self.assertEqual("DataFrame[数量: bigint]", str(df))
|
|
|
|
self.assertEqual([("数量", 'bigint')], df.dtypes)
|
|
|
|
self.assertEqual(1, df.select("数量").first()[0])
|
|
|
|
self.assertEqual(1, df.select(df["数量"]).first()[0])
|
|
|
|
|
2015-04-16 20:33:57 -04:00
|
|
|
def test_access_nested_types(self):
|
|
|
|
df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
|
|
|
|
self.assertEqual(1, df.select(df.l[0]).first()[0])
|
|
|
|
self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
|
|
|
|
self.assertEqual(1, df.select(df.r.a).first()[0])
|
|
|
|
self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
|
|
|
|
self.assertEqual("v", df.select(df.d["k"]).first()[0])
|
|
|
|
self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
|
|
|
|
|
2015-05-08 14:49:38 -04:00
|
|
|
def test_field_accessor(self):
|
|
|
|
df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
|
|
|
|
self.assertEqual(1, df.select(df.l[0]).first()[0])
|
|
|
|
self.assertEqual(1, df.select(df.r["a"]).first()[0])
|
2015-08-14 17:09:46 -04:00
|
|
|
self.assertEqual(1, df.select(df["r.a"]).first()[0])
|
2015-05-08 14:49:38 -04:00
|
|
|
self.assertEqual("b", df.select(df.r["b"]).first()[0])
|
2015-08-14 17:09:46 -04:00
|
|
|
self.assertEqual("b", df.select(df["r.b"]).first()[0])
|
2015-05-08 14:49:38 -04:00
|
|
|
self.assertEqual("v", df.select(df.d["k"]).first()[0])
|
|
|
|
|
2015-02-18 17:17:04 -05:00
|
|
|
def test_infer_long_type(self):
|
|
|
|
longrow = [Row(f1='a', f2=100000000000000)]
|
|
|
|
df = self.sc.parallelize(longrow).toDF()
|
|
|
|
self.assertEqual(df.schema.fields[1].dataType, LongType())
|
|
|
|
|
|
|
|
# saving this as Parquet caused issues as well.
|
|
|
|
output_dir = os.path.join(self.tempdir.name, "infer_long_type")
|
2016-01-04 21:02:38 -05:00
|
|
|
df.write.parquet(output_dir)
|
2016-05-11 14:24:16 -04:00
|
|
|
df1 = self.spark.read.parquet(output_dir)
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual('a', df1.first().f1)
|
|
|
|
self.assertEqual(100000000000000, df1.first().f2)
|
2015-02-18 17:17:04 -05:00
|
|
|
|
|
|
|
self.assertEqual(_infer_type(1), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**10), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**20), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**31 - 1), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**31), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**61), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**71), LongType())
|
|
|
|
|
2018-01-08 00:32:05 -05:00
|
|
|
def test_merge_type(self):
|
|
|
|
self.assertEqual(_merge_type(LongType(), NullType()), LongType())
|
|
|
|
self.assertEqual(_merge_type(NullType(), LongType()), LongType())
|
|
|
|
|
|
|
|
self.assertEqual(_merge_type(LongType(), LongType()), LongType())
|
|
|
|
|
|
|
|
self.assertEqual(_merge_type(
|
|
|
|
ArrayType(LongType()),
|
|
|
|
ArrayType(LongType())
|
|
|
|
), ArrayType(LongType()))
|
|
|
|
with self.assertRaisesRegexp(TypeError, 'element in array'):
|
|
|
|
_merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
|
|
|
|
|
|
|
|
self.assertEqual(_merge_type(
|
|
|
|
MapType(StringType(), LongType()),
|
|
|
|
MapType(StringType(), LongType())
|
|
|
|
), MapType(StringType(), LongType()))
|
|
|
|
with self.assertRaisesRegexp(TypeError, 'key of map'):
|
|
|
|
_merge_type(
|
|
|
|
MapType(StringType(), LongType()),
|
|
|
|
MapType(DoubleType(), LongType()))
|
|
|
|
with self.assertRaisesRegexp(TypeError, 'value of map'):
|
|
|
|
_merge_type(
|
|
|
|
MapType(StringType(), LongType()),
|
|
|
|
MapType(StringType(), DoubleType()))
|
|
|
|
|
|
|
|
self.assertEqual(_merge_type(
|
|
|
|
StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
|
|
|
|
StructType([StructField("f1", LongType()), StructField("f2", StringType())])
|
|
|
|
), StructType([StructField("f1", LongType()), StructField("f2", StringType())]))
|
|
|
|
with self.assertRaisesRegexp(TypeError, 'field f1'):
|
|
|
|
_merge_type(
|
|
|
|
StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
|
|
|
|
StructType([StructField("f1", DoubleType()), StructField("f2", StringType())]))
|
|
|
|
|
|
|
|
self.assertEqual(_merge_type(
|
|
|
|
StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
|
|
|
|
StructType([StructField("f1", StructType([StructField("f2", LongType())]))])
|
|
|
|
), StructType([StructField("f1", StructType([StructField("f2", LongType())]))]))
|
|
|
|
with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
|
|
|
|
_merge_type(
|
|
|
|
StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
|
|
|
|
StructType([StructField("f1", StructType([StructField("f2", StringType())]))]))
|
|
|
|
|
|
|
|
self.assertEqual(_merge_type(
|
|
|
|
StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]),
|
|
|
|
StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])
|
|
|
|
), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]))
|
|
|
|
with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
|
|
|
|
_merge_type(
|
|
|
|
StructType([
|
|
|
|
StructField("f1", ArrayType(LongType())),
|
|
|
|
StructField("f2", StringType())]),
|
|
|
|
StructType([
|
|
|
|
StructField("f1", ArrayType(DoubleType())),
|
|
|
|
StructField("f2", StringType())]))
|
|
|
|
|
|
|
|
self.assertEqual(_merge_type(
|
|
|
|
StructType([
|
|
|
|
StructField("f1", MapType(StringType(), LongType())),
|
|
|
|
StructField("f2", StringType())]),
|
|
|
|
StructType([
|
|
|
|
StructField("f1", MapType(StringType(), LongType())),
|
|
|
|
StructField("f2", StringType())])
|
|
|
|
), StructType([
|
|
|
|
StructField("f1", MapType(StringType(), LongType())),
|
|
|
|
StructField("f2", StringType())]))
|
|
|
|
with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
|
|
|
|
_merge_type(
|
|
|
|
StructType([
|
|
|
|
StructField("f1", MapType(StringType(), LongType())),
|
|
|
|
StructField("f2", StringType())]),
|
|
|
|
StructType([
|
|
|
|
StructField("f1", MapType(StringType(), DoubleType())),
|
|
|
|
StructField("f2", StringType())]))
|
|
|
|
|
|
|
|
self.assertEqual(_merge_type(
|
|
|
|
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
|
|
|
|
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])
|
|
|
|
), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]))
|
|
|
|
with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
|
|
|
|
_merge_type(
|
|
|
|
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
|
|
|
|
StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))])
|
|
|
|
)
|
|
|
|
|
2015-04-21 03:08:18 -04:00
|
|
|
def test_filter_with_datetime(self):
|
|
|
|
time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
|
|
|
|
date = time.date()
|
|
|
|
row = Row(date=date, time=time)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-04-21 03:08:18 -04:00
|
|
|
self.assertEqual(1, df.filter(df.date == date).count())
|
|
|
|
self.assertEqual(1, df.filter(df.time == time).count())
|
|
|
|
self.assertEqual(0, df.filter(df.date > date).count())
|
|
|
|
self.assertEqual(0, df.filter(df.time > time).count())
|
|
|
|
|
[SPARK-10162] [SQL] Fix the timezone omitting for PySpark Dataframe filter function
This PR addresses [SPARK-10162](https://issues.apache.org/jira/browse/SPARK-10162)
The issue is with DataFrame filter() function, if datetime.datetime is passed to it:
* Timezone information of this datetime is ignored
* This datetime is assumed to be in local timezone, which depends on the OS timezone setting
Fix includes both code change and regression test. Problem reproduction code on master:
```python
import pytz
from datetime import datetime
from pyspark.sql import *
from pyspark.sql.types import *
sqc = SQLContext(sc)
df = sqc.createDataFrame([], StructType([StructField("dt", TimestampType())]))
m1 = pytz.timezone('UTC')
m2 = pytz.timezone('Etc/GMT+3')
df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
```
It gives the same timestamp ignoring time zone:
```
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
Filter (dt#0 > 946713600000000)
Scan PhysicalRDD[dt#0]
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
Filter (dt#0 > 946713600000000)
Scan PhysicalRDD[dt#0]
```
After the fix:
```
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
Filter (dt#0 > 946684800000000)
Scan PhysicalRDD[dt#0]
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
Filter (dt#0 > 946695600000000)
Scan PhysicalRDD[dt#0]
```
PR [8536](https://github.com/apache/spark/pull/8536) was occasionally closed by me dropping the repo
Author: 0x0FFF <programmerag@gmail.com>
Closes #8555 from 0x0FFF/SPARK-10162.
2015-09-01 17:34:59 -04:00
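A Python 3 restatement of the reproduction above, sketching the post-fix behaviour (assumes pytz is available and `spark` is an existing SparkSession):
```python
import datetime
import pytz
from pyspark.sql.types import StructType, StructField, TimestampType

df = spark.createDataFrame([], StructType([StructField("dt", TimestampType())]))
utc = pytz.timezone('UTC')
gmt3 = pytz.timezone('Etc/GMT+3')
# With the fix, the two filters compile to different timestamps because tzinfo is honoured.
df.filter(df.dt > datetime.datetime(2000, 1, 1, tzinfo=utc)).explain()
df.filter(df.dt > datetime.datetime(2000, 1, 1, tzinfo=gmt3)).explain()
```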
|
|
|
def test_filter_with_datetime_timezone(self):
|
|
|
|
dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
|
|
|
|
dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
|
|
|
|
row = Row(date=dt1)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
|
|
|
self.assertEqual(0, df.filter(df.date == dt2).count())
|
|
|
|
self.assertEqual(1, df.filter(df.date > dt2).count())
|
|
|
|
self.assertEqual(0, df.filter(df.date < dt2).count())
|
|
|
|
|
2015-06-11 04:00:41 -04:00
|
|
|
def test_time_with_timezone(self):
|
|
|
|
day = datetime.date.today()
|
|
|
|
now = datetime.datetime.now()
|
2015-07-10 16:05:23 -04:00
|
|
|
ts = time.mktime(now.timetuple())
|
2015-06-11 04:00:41 -04:00
|
|
|
# class in __main__ is not serializable
|
|
|
|
from pyspark.sql.tests import UTCOffsetTimezone
|
|
|
|
utc = UTCOffsetTimezone()
|
2015-07-10 16:05:23 -04:00
|
|
|
utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds
|
2015-07-10 20:44:21 -04:00
|
|
|
# add microseconds to utcnow (keeping year,month,day,hour,minute,second)
|
2015-07-10 16:05:23 -04:00
|
|
|
utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([(day, now, utcnow)])
|
2015-06-11 04:00:41 -04:00
|
|
|
day1, now1, utcnow1 = df.first()
|
2015-07-09 17:43:38 -04:00
|
|
|
self.assertEqual(day1, day)
|
|
|
|
self.assertEqual(now, now1)
|
|
|
|
self.assertEqual(now, utcnow1)
|
2015-06-11 04:00:41 -04:00
|
|
|
|
2017-03-09 13:34:54 -05:00
|
|
|
# regression test for SPARK-19561
|
|
|
|
def test_datetime_at_epoch(self):
|
|
|
|
epoch = datetime.datetime.fromtimestamp(0)
|
|
|
|
df = self.spark.createDataFrame([Row(date=epoch)])
|
|
|
|
first = df.select('date', lit(epoch).alias('lit_date')).first()
|
|
|
|
self.assertEqual(first['date'], epoch)
|
|
|
|
self.assertEqual(first['lit_date'], epoch)
|
|
|
|
|
2017-11-09 00:44:39 -05:00
|
|
|
def test_dayofweek(self):
|
|
|
|
from pyspark.sql.functions import dayofweek
|
|
|
|
dt = datetime.datetime(2017, 11, 6)
|
|
|
|
df = self.spark.createDataFrame([Row(date=dt)])
|
|
|
|
row = df.select(dayofweek(df.date)).first()
|
|
|
|
self.assertEqual(row[0], 2)
|
|
|
|
|
2015-07-08 21:22:53 -04:00
|
|
|
def test_decimal(self):
|
|
|
|
from decimal import Decimal
|
|
|
|
schema = StructType([StructField("decimal", DecimalType(10, 5))])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema)
|
2015-07-08 21:22:53 -04:00
|
|
|
row = df.select(df.decimal + 1).first()
|
|
|
|
self.assertEqual(row[0], Decimal("4.14159"))
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
df.write.parquet(tmpPath)
|
2016-05-11 14:24:16 -04:00
|
|
|
df2 = self.spark.read.parquet(tmpPath)
|
2015-07-08 21:22:53 -04:00
|
|
|
row = df2.first()
|
|
|
|
self.assertEqual(row[0], Decimal("3.14159"))
|
|
|
|
|
2015-03-30 23:47:10 -04:00
|
|
|
def test_dropna(self):
|
|
|
|
schema = StructType([
|
|
|
|
StructField("name", StringType(), True),
|
|
|
|
StructField("age", IntegerType(), True),
|
|
|
|
StructField("height", DoubleType(), True)])
|
|
|
|
|
|
|
|
# shouldn't drop a non-null row
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', 50, 80.1)], schema).dropna().count(),
|
|
|
|
1)
|
|
|
|
|
|
|
|
# dropping rows with a single null value
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, 80.1)], schema).dropna().count(),
|
|
|
|
0)
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# if how = 'all', only drop rows if all values are null
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
|
|
|
|
1)
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(None, None, None)], schema).dropna(how='all').count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# how and subset
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
|
|
|
|
1)
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# threshold
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
|
|
|
|
1)
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, None)], schema).dropna(thresh=2).count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# threshold and subset
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
|
|
|
|
1)
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# thresh should take precedence over how
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', 50, None)], schema).dropna(
|
|
|
|
how='any', thresh=2, subset=['name', 'age']).count(),
|
|
|
|
1)
|
|
|
|
|
|
|
|
def test_fillna(self):
|
|
|
|
schema = StructType([
|
|
|
|
StructField("name", StringType(), True),
|
|
|
|
StructField("age", IntegerType(), True),
|
2017-06-03 01:56:42 -04:00
|
|
|
StructField("height", DoubleType(), True),
|
|
|
|
StructField("spy", BooleanType(), True)])
|
2015-03-30 23:47:10 -04:00
|
|
|
|
|
|
|
# fillna shouldn't change non-null values
|
2017-06-03 01:56:42 -04:00
|
|
|
row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first()
|
2015-03-30 23:47:10 -04:00
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
|
|
|
|
# fillna with int
|
2017-06-03 01:56:42 -04:00
|
|
|
row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first()
|
2015-03-30 23:47:10 -04:00
|
|
|
self.assertEqual(row.age, 50)
|
|
|
|
self.assertEqual(row.height, 50.0)
|
|
|
|
|
|
|
|
# fillna with double
|
2017-06-03 01:56:42 -04:00
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', None, None, None)], schema).fillna(50.1).first()
|
2015-03-30 23:47:10 -04:00
|
|
|
self.assertEqual(row.age, 50)
|
|
|
|
self.assertEqual(row.height, 50.1)
|
|
|
|
|
2017-06-03 01:56:42 -04:00
|
|
|
# fillna with bool
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', None, None, None)], schema).fillna(True).first()
|
|
|
|
self.assertEqual(row.age, None)
|
|
|
|
self.assertEqual(row.spy, True)
|
|
|
|
|
2015-03-30 23:47:10 -04:00
|
|
|
# fillna with string
|
2017-06-03 01:56:42 -04:00
|
|
|
row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first()
|
2015-03-30 23:47:10 -04:00
|
|
|
self.assertEqual(row.name, u"hello")
|
|
|
|
self.assertEqual(row.age, None)
|
|
|
|
|
|
|
|
# fillna with subset specified for numeric cols
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2017-06-03 01:56:42 -04:00
|
|
|
[(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
|
2015-03-30 23:47:10 -04:00
|
|
|
self.assertEqual(row.name, None)
|
|
|
|
self.assertEqual(row.age, 50)
|
|
|
|
self.assertEqual(row.height, None)
|
2017-06-03 01:56:42 -04:00
|
|
|
self.assertEqual(row.spy, None)
|
2015-03-30 23:47:10 -04:00
|
|
|
|
2017-06-03 01:56:42 -04:00
|
|
|
# fillna with subset specified for string cols
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2017-06-03 01:56:42 -04:00
|
|
|
[(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
|
2015-03-30 23:47:10 -04:00
|
|
|
self.assertEqual(row.name, "haha")
|
|
|
|
self.assertEqual(row.age, None)
|
|
|
|
self.assertEqual(row.height, None)
|
2017-06-03 01:56:42 -04:00
|
|
|
self.assertEqual(row.spy, None)
|
|
|
|
|
|
|
|
# fillna with subset specified for bool cols
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first()
|
|
|
|
self.assertEqual(row.name, None)
|
|
|
|
self.assertEqual(row.age, None)
|
|
|
|
self.assertEqual(row.height, None)
|
|
|
|
self.assertEqual(row.spy, True)
|
2015-03-30 23:47:10 -04:00
|
|
|
|
2017-05-01 00:42:05 -04:00
|
|
|
# fillna with dictionary for boolean types
|
|
|
|
row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first()
|
|
|
|
self.assertEqual(row.a, True)
|
|
|
|
|
2015-05-07 04:00:29 -04:00
|
|
|
def test_bitwise_operations(self):
|
|
|
|
from pyspark.sql import functions
|
|
|
|
row = Row(a=170, b=75)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-05-07 04:00:29 -04:00
|
|
|
result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(170 & 75, result['(a & b)'])
|
|
|
|
result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(170 | 75, result['(a | b)'])
|
|
|
|
result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(170 ^ 75, result['(a ^ b)'])
|
|
|
|
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(~75, result['~b'])
|
|
|
|
|
2015-07-25 03:34:59 -04:00
|
|
|
def test_expr(self):
|
|
|
|
from pyspark.sql import functions
|
|
|
|
row = Row(a="length string", b=75)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-07-25 03:34:59 -04:00
|
|
|
result = df.select(functions.expr("length(a)")).collect()[0].asDict()
|
2015-11-10 14:06:29 -05:00
|
|
|
self.assertEqual(13, result["length(a)"])
|
2015-07-25 03:34:59 -04:00
|
|
|
|
2018-02-11 05:23:15 -05:00
|
|
|
def test_repartitionByRange_dataframe(self):
|
|
|
|
schema = StructType([
|
|
|
|
StructField("name", StringType(), True),
|
|
|
|
StructField("age", IntegerType(), True),
|
|
|
|
StructField("height", DoubleType(), True)])
|
|
|
|
|
|
|
|
df1 = self.spark.createDataFrame(
|
|
|
|
[(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema)
|
|
|
|
df2 = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema)
|
|
|
|
|
|
|
|
# test repartitionByRange(numPartitions, *cols)
|
|
|
|
df3 = df1.repartitionByRange(2, "name", "age")
|
|
|
|
self.assertEqual(df3.rdd.getNumPartitions(), 2)
|
|
|
|
self.assertEqual(df3.rdd.first(), df2.rdd.first())
|
|
|
|
self.assertEqual(df3.rdd.take(3), df2.rdd.take(3))
|
|
|
|
|
|
|
|
# test repartitionByRange(numPartitions, *cols)
|
|
|
|
df4 = df1.repartitionByRange(3, "name", "age")
|
|
|
|
self.assertEqual(df4.rdd.getNumPartitions(), 3)
|
|
|
|
self.assertEqual(df4.rdd.first(), df2.rdd.first())
|
|
|
|
self.assertEqual(df4.rdd.take(3), df2.rdd.take(3))
|
|
|
|
|
|
|
|
# test repartitionByRange(*cols)
|
|
|
|
df5 = df1.repartitionByRange("name", "age")
|
|
|
|
self.assertEqual(df5.rdd.first(), df2.rdd.first())
|
|
|
|
self.assertEqual(df5.rdd.take(3), df2.rdd.take(3))
|
|
|
|
|
2015-05-12 13:23:41 -04:00
|
|
|
def test_replace(self):
|
|
|
|
schema = StructType([
|
|
|
|
StructField("name", StringType(), True),
|
|
|
|
StructField("age", IntegerType(), True),
|
|
|
|
StructField("height", DoubleType(), True)])
|
|
|
|
|
|
|
|
# replace with int
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
|
2015-05-12 13:23:41 -04:00
|
|
|
self.assertEqual(row.age, 20)
|
|
|
|
self.assertEqual(row.height, 20.0)
|
|
|
|
|
|
|
|
# replace with double
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
|
|
|
|
self.assertEqual(row.age, 82)
|
|
|
|
self.assertEqual(row.height, 82.1)
|
|
|
|
|
|
|
|
# replace with string
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
|
|
|
|
self.assertEqual(row.name, u"Ann")
|
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
|
|
|
|
# replace with subset specified by a string of a column name w/ actual change
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
|
|
|
|
self.assertEqual(row.age, 20)
|
|
|
|
|
|
|
|
# replace with subset specified by a string of a column name w/o actual change
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
|
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
|
|
|
|
# replace with subset specified with one column replaced, another column not in subset
|
|
|
|
# stays unchanged.
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
|
|
|
|
self.assertEqual(row.name, u'Alice')
|
|
|
|
self.assertEqual(row.age, 20)
|
|
|
|
self.assertEqual(row.height, 10.0)
|
|
|
|
|
|
|
|
# replace with subset specified but no column will be replaced
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
|
|
|
|
self.assertEqual(row.name, u'Alice')
|
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
self.assertEqual(row.height, None)
|
|
|
|
|
2017-04-05 14:47:40 -04:00
|
|
|
# replace with lists
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
|
|
|
|
self.assertTupleEqual(row, (u'Ann', 10, 80.1))
|
|
|
|
|
|
|
|
# replace with dict
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', 11, 80.1))
|
|
|
|
|
|
|
|
# test backward compatibility with dummy value
|
|
|
|
dummy_value = 1
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first()
|
|
|
|
self.assertTupleEqual(row, (u'Bob', 10, 80.1))
|
|
|
|
|
|
|
|
# test dict with mixed numerics
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', -10, 90.5))
|
|
|
|
|
|
|
|
# replace with tuples
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
|
|
|
|
self.assertTupleEqual(row, (u'Bob', 10, 80.1))
|
|
|
|
|
|
|
|
# replace multiple columns
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', 20, 90.0))
|
|
|
|
|
|
|
|
# test for mixed numerics
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', 20, 90.5))
|
|
|
|
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', 20, 90.5))
|
|
|
|
|
|
|
|
# replace with boolean
|
|
|
|
row = (self
|
|
|
|
.spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
|
|
|
|
.selectExpr("name = 'Bob'", 'age <= 15')
|
|
|
|
.replace(False, True).first())
|
|
|
|
self.assertTupleEqual(row, (True, True))
|
|
|
|
|
[SPARK-14932][SQL] Allow DataFrame.replace() to replace values with None
## What changes were proposed in this pull request?
Currently `df.na.replace("*", Map[String, String]("NULL" -> null))` will produce exception.
This PR enables passing null/None as value in the replacement map in DataFrame.replace().
Note that the replacement map keys and values should still be the same type, while the values can have a mix of null/None and that type.
This PR enables following operations for example:
`df.na.replace("*", Map[String, String]("NULL" -> null))`(scala)
`df.na.replace("*", Map[Any, Any](60 -> null, 70 -> 80))`(scala)
`df.na.replace('Alice', None)`(python)
`df.na.replace([10, 20])`(python, replacing with None is by default)
One use case could be: I want to replace all the empty strings with null/None because they were incorrectly generated and then drop all null/None data
`df.na.replace("*", Map("" -> null)).na.drop()`(scala)
`df.replace(u'', None).dropna()`(python)
## How was this patch tested?
Scala unit test.
Python doctest and unit test.
Author: bravo-zhang <mzhang1230@gmail.com>
Closes #18820 from bravo-zhang/spark-14932.
2017-08-09 20:42:21 -04:00
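A small sketch of the replace-with-None behaviour enabled by the change above (column names and values are illustrative):
```python
# Empty strings were generated incorrectly: turn them into nulls, then drop those rows.
df = spark.createDataFrame([(u'', 1), (u'x', 2)], ['name', 'id'])
df.replace(u'', None).dropna().show()
# Equivalent through the `na` accessor:
df.na.replace(u'', None).dropna().show()
```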
|
|
|
# replace string with None and then drop None rows
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna()
|
|
|
|
self.assertEqual(row.count(), 0)
|
|
|
|
|
|
|
|
# replace with number and None
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', 20, None))
|
|
|
|
|
2017-04-05 14:47:40 -04:00
|
|
|
# should fail if subset is not list, tuple or None
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()
|
|
|
|
|
|
|
|
# should fail if to_replace and value have different length
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()
|
|
|
|
|
|
|
|
# should fail when it receives an unexpected type
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
from datetime import datetime
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()
|
|
|
|
|
|
|
|
# should fail if provided mixed type replacements
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()
|
|
|
|
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()
|
|
|
|
|
2018-02-09 01:21:10 -05:00
|
|
|
with self.assertRaisesRegexp(
|
|
|
|
TypeError,
|
|
|
|
'value argument is required when to_replace is not a dictionary.'):
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()
|
|
|
|
|
2015-06-30 19:17:46 -04:00
|
|
|
def test_capture_analysis_exception(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
|
2015-06-30 19:17:46 -04:00
|
|
|
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
|
2016-03-28 15:31:12 -04:00
|
|
|
|
|
|
|
def test_capture_parse_exception(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
|
2015-06-30 19:17:46 -04:00
|
|
|
|
2015-07-19 03:32:56 -04:00
|
|
|
def test_capture_illegalargument_exception(self):
|
|
|
|
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
|
2016-05-11 14:24:16 -04:00
|
|
|
lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
|
|
|
|
df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
|
2015-07-19 03:32:56 -04:00
|
|
|
self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
|
|
|
|
lambda: df.select(sha2(df.a, 1024)).collect())
|
2015-10-29 00:45:00 -04:00
|
|
|
try:
|
|
|
|
df.select(sha2(df.a, 1024)).collect()
|
|
|
|
except IllegalArgumentException as e:
|
|
|
|
self.assertRegexpMatches(e.desc, "1024 is not in the permitted values")
|
|
|
|
self.assertRegexpMatches(e.stackTrace,
|
|
|
|
"org.apache.spark.sql.functions")
|
2015-07-19 03:32:56 -04:00
|
|
|
|
2015-08-19 16:56:40 -04:00
|
|
|
def test_with_column_with_existing_name(self):
|
|
|
|
keys = self.df.withColumn("key", self.df.key).select("key").collect()
|
|
|
|
self.assertEqual([r.key for r in keys], list(range(100)))
|
|
|
|
|
2015-09-02 16:36:36 -04:00
|
|
|
# regression test for SPARK-10417
|
|
|
|
def test_column_iterator(self):
|
|
|
|
|
|
|
|
def foo():
|
|
|
|
for x in self.df.key:
|
|
|
|
break
|
|
|
|
|
|
|
|
self.assertRaises(TypeError, foo)
|
|
|
|
|
2015-09-22 02:36:41 -04:00
|
|
|
# add test for SPARK-10577 (test broadcast join hint)
|
|
|
|
def test_functions_broadcast(self):
|
|
|
|
from pyspark.sql.functions import broadcast
|
|
|
|
|
2016-05-11 14:24:16 -04:00
|
|
|
df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
|
|
|
|
df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
|
2015-09-22 02:36:41 -04:00
|
|
|
|
|
|
|
# equijoin - should be converted into broadcast join
|
|
|
|
plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
|
|
|
|
self.assertEqual(1, plan1.toString().count("BroadcastHashJoin"))
|
|
|
|
|
|
|
|
# no join key -- should not be a broadcast join
|
2016-10-14 21:24:47 -04:00
|
|
|
plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan()
|
2015-09-22 02:36:41 -04:00
|
|
|
self.assertEqual(0, plan2.toString().count("BroadcastHashJoin"))
|
|
|
|
|
|
|
|
# planner should not crash without a join
|
|
|
|
broadcast(df1)._jdf.queryExecution().executedPlan()
|
|
|
|
|
2017-05-03 22:15:28 -04:00
|
|
|
def test_generic_hints(self):
|
|
|
|
from pyspark.sql import DataFrame
|
|
|
|
|
|
|
|
df1 = self.spark.range(10e10).toDF("id")
|
|
|
|
df2 = self.spark.range(10e10).toDF("id")
|
|
|
|
|
|
|
|
self.assertIsInstance(df1.hint("broadcast"), DataFrame)
|
|
|
|
self.assertIsInstance(df1.hint("broadcast", []), DataFrame)
|
|
|
|
|
|
|
|
# Dummy rules
|
|
|
|
self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
|
|
|
|
self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)
|
|
|
|
|
|
|
|
plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
|
|
|
|
self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
|
|
|
|
|
[SPARK-21779][PYTHON] Simpler DataFrame.sample API in Python
## What changes were proposed in this pull request?
This PR make `DataFrame.sample(...)` can omit `withReplacement` defaulting `False`, consistently with equivalent Scala / Java API.
In short, the following examples are allowed:
```python
>>> df = spark.range(10)
>>> df.sample(0.5).count()
7
>>> df.sample(fraction=0.5).count()
3
>>> df.sample(0.5, seed=42).count()
5
>>> df.sample(fraction=0.5, seed=42).count()
5
```
In addition, this PR also adds some type checking logics as below:
```python
>>> df = spark.range(10)
>>> df.sample().count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [].
>>> df.sample(True).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'bool'>].
>>> df.sample(42).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'int'>].
>>> df.sample(fraction=False, seed="a").count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'bool'>, <type 'str'>].
>>> df.sample(seed=[1]).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'list'>].
>>> df.sample(withReplacement="a", fraction=0.5, seed=1)
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'str'>, <type 'float'>, <type 'int'>].
```
## How was this patch tested?
Manually tested, unit tests added in doc tests and manually checked the built documentation for Python.
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #18999 from HyukjinKwon/SPARK-21779.
2017-09-01 00:01:23 -04:00
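In short, the simplified signature sketched above allows `withReplacement` to be omitted; a few equivalent call forms (fractions and seeds are arbitrary):
```python
df = spark.range(10)
df.sample(0.5).count()                     # fraction only
df.sample(fraction=0.5, seed=42).count()   # keyword form
df.sample(True, 0.5, 42).count()           # original (withReplacement, fraction, seed) form
```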
|
|
|
def test_sample(self):
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
TypeError,
|
|
|
|
"should be a bool, float and number",
|
|
|
|
lambda: self.spark.range(1).sample())
|
|
|
|
|
|
|
|
self.assertRaises(
|
|
|
|
TypeError,
|
|
|
|
lambda: self.spark.range(1).sample("a"))
|
|
|
|
|
|
|
|
self.assertRaises(
|
|
|
|
TypeError,
|
|
|
|
lambda: self.spark.range(1).sample(seed="abc"))
|
|
|
|
|
|
|
|
self.assertRaises(
|
|
|
|
IllegalArgumentException,
|
|
|
|
lambda: self.spark.range(1).sample(-1.0))
|
|
|
|
|
2016-03-08 17:00:03 -05:00
|
|
|
def test_toDF_with_schema_string(self):
|
|
|
|
data = [Row(key=i, value=str(i)) for i in range(100)]
|
|
|
|
rdd = self.sc.parallelize(data, 5)
|
|
|
|
|
|
|
|
df = rdd.toDF("key: int, value: string")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
|
|
|
|
self.assertEqual(df.collect(), data)
|
|
|
|
|
|
|
|
# different but compatible field types can be used.
|
|
|
|
df = rdd.toDF("key: string, value: string")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
|
|
|
|
self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
|
|
|
|
|
|
|
|
# field names can differ.
|
|
|
|
df = rdd.toDF(" a: int, b: string ")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
|
|
|
|
self.assertEqual(df.collect(), data)
|
|
|
|
|
|
|
|
# number of fields must match.
|
|
|
|
self.assertRaisesRegexp(Exception, "Length of object",
|
|
|
|
lambda: rdd.toDF("key: int").collect())
|
|
|
|
|
|
|
|
# field types mismatch will cause exception at runtime.
|
|
|
|
self.assertRaisesRegexp(Exception, "FloatType can not accept",
|
|
|
|
lambda: rdd.toDF("key: float, value: string").collect())
|
|
|
|
|
|
|
|
# flat schema values will be wrapped into row.
|
|
|
|
df = rdd.map(lambda row: row.key).toDF("int")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<value:int>")
|
|
|
|
self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
|
|
|
|
|
|
|
|
# users can use DataType directly instead of data type string.
|
|
|
|
df = rdd.map(lambda row: row.key).toDF(IntegerType())
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<value:int>")
|
|
|
|
self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
|
|
|
|
|
[SPARK-21264][PYTHON] Call cross join path in join without 'on' and with 'how'
## What changes were proposed in this pull request?
Currently, it throws a NPE when missing columns but join type is speicified in join at PySpark as below:
```python
spark.conf.set("spark.sql.crossJoin.enabled", "false")
spark.range(1).join(spark.range(1), how="inner").show()
```
```
Traceback (most recent call last):
...
py4j.protocol.Py4JJavaError: An error occurred while calling o66.join.
: java.lang.NullPointerException
at org.apache.spark.sql.Dataset.join(Dataset.scala:931)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
...
```
```python
spark.conf.set("spark.sql.crossJoin.enabled", "true")
spark.range(1).join(spark.range(1), how="inner").show()
```
```
...
py4j.protocol.Py4JJavaError: An error occurred while calling o84.join.
: java.lang.NullPointerException
at org.apache.spark.sql.Dataset.join(Dataset.scala:931)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
...
```
This PR suggests following Scala's behavior, as below:
```scala
scala> spark.conf.set("spark.sql.crossJoin.enabled", "false")
scala> spark.range(1).join(spark.range(1), Seq.empty[String], "inner").show()
```
```
org.apache.spark.sql.AnalysisException: Detected cartesian product for INNER join between logical plans
Range (0, 1, step=1, splits=Some(8))
and
Range (0, 1, step=1, splits=Some(8))
Join condition is missing or trivial.
Use the CROSS JOIN syntax to allow cartesian products between these relations.;
...
```
```scala
scala> spark.conf.set("spark.sql.crossJoin.enabled", "true")
scala> spark.range(1).join(spark.range(1), Seq.empty[String], "inner").show()
```
```
+---+---+
| id| id|
+---+---+
| 0| 0|
+---+---+
```
**After**
```python
spark.conf.set("spark.sql.crossJoin.enabled", "false")
spark.range(1).join(spark.range(1), how="inner").show()
```
```
Traceback (most recent call last):
...
pyspark.sql.utils.AnalysisException: u'Detected cartesian product for INNER join between logical plans\nRange (0, 1, step=1, splits=Some(8))\nand\nRange (0, 1, step=1, splits=Some(8))\nJoin condition is missing or trivial.\nUse the CROSS JOIN syntax to allow cartesian products between these relations.;'
```
```python
spark.conf.set("spark.sql.crossJoin.enabled", "true")
spark.range(1).join(spark.range(1), how="inner").show()
```
```
+---+---+
| id| id|
+---+---+
| 0| 0|
+---+---+
```
## How was this patch tested?
Added tests in `python/pyspark/sql/tests.py`.
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #18484 from HyukjinKwon/SPARK-21264.
2017-07-03 22:35:08 -04:00
|
|
|
def test_join_without_on(self):
|
|
|
|
df1 = self.spark.range(1).toDF("a")
|
|
|
|
df2 = self.spark.range(1).toDF("b")
|
|
|
|
|
|
|
|
try:
|
|
|
|
self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
|
|
|
|
self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
|
|
|
|
|
|
|
|
self.spark.conf.set("spark.sql.crossJoin.enabled", "true")
|
|
|
|
actual = df1.join(df2, how="inner").collect()
|
|
|
|
expected = [Row(a=0, b=0)]
|
|
|
|
self.assertEqual(actual, expected)
|
|
|
|
finally:
|
|
|
|
# We should unset this. Otherwise, other tests are affected.
|
|
|
|
self.spark.conf.unset("spark.sql.crossJoin.enabled")
|
|
|
|
|
2016-10-12 13:09:49 -04:00
|
|
|
# Regression test for invalid join methods when on is None (SPARK-14761)
|
|
|
|
def test_invalid_join_method(self):
|
|
|
|
df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
|
|
|
|
df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"])
|
|
|
|
self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type"))
|
|
|
|
|
2016-10-14 21:24:47 -04:00
|
|
|
# Cartesian products require cross join syntax
|
|
|
|
def test_require_cross(self):
|
|
|
|
from pyspark.sql.functions import broadcast
|
|
|
|
|
|
|
|
df1 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
|
|
|
|
df2 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
|
|
|
|
|
|
|
|
# joins without conditions require cross join syntax
|
|
|
|
self.assertRaises(AnalysisException, lambda: df1.join(df2).collect())
|
|
|
|
|
|
|
|
# works with crossJoin
|
|
|
|
self.assertEqual(1, df1.crossJoin(df2).count())
|
|
|
|
|
2016-04-29 19:41:13 -04:00
|
|
|
def test_conf(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 23:46:07 -04:00
|
|
|
spark.conf.set("bogo", "sipeo")
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(spark.conf.get("bogo"), "sipeo")
|
2016-04-29 23:46:07 -04:00
|
|
|
spark.conf.set("bogo", "ta")
|
2016-04-29 19:41:13 -04:00
|
|
|
self.assertEqual(spark.conf.get("bogo"), "ta")
|
|
|
|
self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
|
|
|
|
self.assertEqual(spark.conf.get("not.set", "ta"), "ta")
|
|
|
|
self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set"))
|
|
|
|
spark.conf.unset("bogo")
|
|
|
|
self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
|
|
|
|
|
|
|
|
def test_current_database(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 19:41:13 -04:00
|
|
|
spark.catalog._reset()
|
|
|
|
self.assertEquals(spark.catalog.currentDatabase(), "default")
|
|
|
|
spark.sql("CREATE DATABASE some_db")
|
|
|
|
spark.catalog.setCurrentDatabase("some_db")
|
|
|
|
self.assertEquals(spark.catalog.currentDatabase(), "some_db")
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
|
|
|
|
|
|
|
|
def test_list_databases(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 19:41:13 -04:00
|
|
|
spark.catalog._reset()
|
|
|
|
databases = [db.name for db in spark.catalog.listDatabases()]
|
|
|
|
self.assertEquals(databases, ["default"])
|
|
|
|
spark.sql("CREATE DATABASE some_db")
|
|
|
|
databases = [db.name for db in spark.catalog.listDatabases()]
|
|
|
|
self.assertEquals(sorted(databases), ["default", "some_db"])
|
|
|
|
|
|
|
|
def test_list_tables(self):
|
|
|
|
from pyspark.sql.catalog import Table
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 19:41:13 -04:00
|
|
|
spark.catalog._reset()
|
|
|
|
spark.sql("CREATE DATABASE some_db")
|
|
|
|
self.assertEquals(spark.catalog.listTables(), [])
|
|
|
|
self.assertEquals(spark.catalog.listTables("some_db"), [])
|
2016-05-17 21:01:59 -04:00
|
|
|
spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
|
2017-01-22 23:37:37 -05:00
|
|
|
spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
|
|
|
|
spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
|
2016-04-29 19:41:13 -04:00
|
|
|
tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
|
|
|
|
tablesDefault = sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
|
|
|
|
tablesSomeDb = sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
|
|
|
|
self.assertEquals(tables, tablesDefault)
|
|
|
|
self.assertEquals(len(tables), 2)
|
|
|
|
self.assertEquals(len(tablesSomeDb), 2)
|
|
|
|
self.assertEquals(tables[0], Table(
|
|
|
|
name="tab1",
|
|
|
|
database="default",
|
|
|
|
description=None,
|
|
|
|
tableType="MANAGED",
|
|
|
|
isTemporary=False))
|
|
|
|
self.assertEquals(tables[1], Table(
|
|
|
|
name="temp_tab",
|
|
|
|
database=None,
|
|
|
|
description=None,
|
|
|
|
tableType="TEMPORARY",
|
|
|
|
isTemporary=True))
|
|
|
|
self.assertEquals(tablesSomeDb[0], Table(
|
|
|
|
name="tab2",
|
|
|
|
database="some_db",
|
|
|
|
description=None,
|
|
|
|
tableType="MANAGED",
|
|
|
|
isTemporary=False))
|
|
|
|
self.assertEquals(tablesSomeDb[1], Table(
|
|
|
|
name="temp_tab",
|
|
|
|
database=None,
|
|
|
|
description=None,
|
|
|
|
tableType="TEMPORARY",
|
|
|
|
isTemporary=True))
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.listTables("does_not_exist"))
|
|
|
|
|
|
|
|
def test_list_functions(self):
|
|
|
|
from pyspark.sql.catalog import Function
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 19:41:13 -04:00
|
|
|
spark.catalog._reset()
|
|
|
|
spark.sql("CREATE DATABASE some_db")
|
|
|
|
functions = dict((f.name, f) for f in spark.catalog.listFunctions())
|
|
|
|
functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
|
2016-06-27 14:50:34 -04:00
|
|
|
self.assertTrue(len(functions) > 200)
|
|
|
|
self.assertTrue("+" in functions)
|
|
|
|
self.assertTrue("like" in functions)
|
|
|
|
self.assertTrue("month" in functions)
|
2017-02-07 09:50:30 -05:00
|
|
|
self.assertTrue("to_date" in functions)
|
|
|
|
self.assertTrue("to_timestamp" in functions)
|
2016-06-27 14:50:34 -04:00
|
|
|
self.assertTrue("to_unix_timestamp" in functions)
|
|
|
|
self.assertTrue("current_database" in functions)
|
|
|
|
self.assertEquals(functions["+"], Function(
|
|
|
|
name="+",
|
|
|
|
description=None,
|
|
|
|
className="org.apache.spark.sql.catalyst.expressions.Add",
|
|
|
|
isTemporary=True))
|
2016-04-29 19:41:13 -04:00
|
|
|
self.assertEquals(functions, functionsDefault)
|
|
|
|
spark.catalog.registerFunction("temp_func", lambda x: str(x))
|
|
|
|
spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
|
|
|
|
spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
|
|
|
|
newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
|
|
|
|
newFunctionsSomeDb = dict((f.name, f) for f in spark.catalog.listFunctions("some_db"))
|
|
|
|
self.assertTrue(set(functions).issubset(set(newFunctions)))
|
|
|
|
self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
|
|
|
|
self.assertTrue("temp_func" in newFunctions)
|
|
|
|
self.assertTrue("func1" in newFunctions)
|
|
|
|
self.assertTrue("func2" not in newFunctions)
|
|
|
|
self.assertTrue("temp_func" in newFunctionsSomeDb)
|
|
|
|
self.assertTrue("func1" not in newFunctionsSomeDb)
|
|
|
|
self.assertTrue("func2" in newFunctionsSomeDb)
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.listFunctions("does_not_exist"))
|
|
|
|
|
|
|
|
def test_list_columns(self):
|
|
|
|
from pyspark.sql.catalog import Column
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 19:41:13 -04:00
|
|
|
spark.catalog._reset()
|
|
|
|
spark.sql("CREATE DATABASE some_db")
|
2017-01-22 23:37:37 -05:00
|
|
|
spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
|
|
|
|
spark.sql("CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet")
|
2016-04-29 19:41:13 -04:00
|
|
|
columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
|
|
|
|
columnsDefault = sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
|
|
|
|
self.assertEquals(columns, columnsDefault)
|
|
|
|
self.assertEquals(len(columns), 2)
|
|
|
|
self.assertEquals(columns[0], Column(
|
|
|
|
name="age",
|
|
|
|
description=None,
|
|
|
|
dataType="int",
|
|
|
|
nullable=True,
|
|
|
|
isPartition=False,
|
|
|
|
isBucket=False))
|
|
|
|
self.assertEquals(columns[1], Column(
|
|
|
|
name="name",
|
|
|
|
description=None,
|
|
|
|
dataType="string",
|
|
|
|
nullable=True,
|
|
|
|
isPartition=False,
|
|
|
|
isBucket=False))
|
|
|
|
columns2 = sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
|
|
|
|
self.assertEquals(len(columns2), 2)
|
|
|
|
self.assertEquals(columns2[0], Column(
|
|
|
|
name="nickname",
|
|
|
|
description=None,
|
|
|
|
dataType="string",
|
|
|
|
nullable=True,
|
|
|
|
isPartition=False,
|
|
|
|
isBucket=False))
|
|
|
|
self.assertEquals(columns2[1], Column(
|
|
|
|
name="tolerance",
|
|
|
|
description=None,
|
|
|
|
dataType="float",
|
|
|
|
nullable=True,
|
|
|
|
isPartition=False,
|
|
|
|
isBucket=False))
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"tab2",
|
|
|
|
lambda: spark.catalog.listColumns("tab2"))
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.listColumns("does_not_exist"))
|
|
|
|
|
|
|
|
def test_cache(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-05-17 21:01:59 -04:00
|
|
|
spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1")
|
|
|
|
spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2")
|
2016-04-29 19:41:13 -04:00
|
|
|
self.assertFalse(spark.catalog.isCached("tab1"))
|
|
|
|
self.assertFalse(spark.catalog.isCached("tab2"))
|
|
|
|
spark.catalog.cacheTable("tab1")
|
|
|
|
self.assertTrue(spark.catalog.isCached("tab1"))
|
|
|
|
self.assertFalse(spark.catalog.isCached("tab2"))
|
|
|
|
spark.catalog.cacheTable("tab2")
|
|
|
|
spark.catalog.uncacheTable("tab1")
|
|
|
|
self.assertFalse(spark.catalog.isCached("tab1"))
|
|
|
|
self.assertTrue(spark.catalog.isCached("tab2"))
|
|
|
|
spark.catalog.clearCache()
|
|
|
|
self.assertFalse(spark.catalog.isCached("tab1"))
|
|
|
|
self.assertFalse(spark.catalog.isCached("tab2"))
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.isCached("does_not_exist"))
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.cacheTable("does_not_exist"))
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.uncacheTable("does_not_exist"))
|
|
|
|
|
2016-10-07 03:27:55 -04:00
|
|
|
def test_read_text_file_list(self):
|
|
|
|
df = self.spark.read.text(['python/test_support/sql/text-test.txt',
|
|
|
|
'python/test_support/sql/text-test.txt'])
|
|
|
|
count = df.count()
|
|
|
|
self.assertEquals(count, 4)
|
|
|
|
|
2016-10-11 02:29:52 -04:00
|
|
|
def test_BinaryType_serialization(self):
|
|
|
|
# Pyrolite version <= 4.9 could not serialize BinaryType with Python3 SPARK-17808
|
2017-08-30 23:55:38 -04:00
|
|
|
# The empty bytearray is a test for SPARK-21534.
|
2016-10-11 02:29:52 -04:00
|
|
|
schema = StructType([StructField('mybytes', BinaryType())])
|
|
|
|
data = [[bytearray(b'here is my data')],
|
2017-08-30 23:55:38 -04:00
|
|
|
[bytearray(b'and here is some more')],
|
|
|
|
[bytearray(b'')]]
|
2016-10-11 02:29:52 -04:00
|
|
|
df = self.spark.createDataFrame(data, schema=schema)
|
|
|
|
df.collect()
|
|
|
|
|
[SPARK-16542][SQL][PYSPARK] Fix bugs about types that result an array of null when creating DataFrame using python
Author: Xiang Gao <qasdfgtyuiop@gmail.com>
Author: Gao, Xiang <qasdfgtyuiop@gmail.com>
Author: Takuya UESHIN <ueshin@databricks.com>
Closes #18444 from zasdfgbnm/fix_array_infer.
2017-07-19 23:46:06 -04:00
|
|
|
# test for SPARK-16542
|
|
|
|
def test_array_types(self):
|
|
|
|
# This test needs to make sure that the Scala type selected is at least
|
|
|
|
# as large as the python's types. This is necessary because python's
|
|
|
|
# array types depend on C implementation on the machine. Therefore there
|
|
|
|
# is no machine independent correspondence between python's array types
|
|
|
|
# and Scala types.
|
|
|
|
# See: https://docs.python.org/2/library/array.html
|
|
|
|
|
|
|
|
def assertCollectSuccess(typecode, value):
|
|
|
|
row = Row(myarray=array.array(typecode, [value]))
|
|
|
|
df = self.spark.createDataFrame([row])
|
|
|
|
self.assertEqual(df.first()["myarray"][0], value)
|
|
|
|
|
|
|
|
# supported string types
|
|
|
|
#
|
|
|
|
# String types in python's array are "u" for Py_UNICODE and "c" for char.
|
|
|
|
# "u" will be removed in python 4, and "c" is not supported in python 3.
|
|
|
|
supported_string_types = []
|
|
|
|
if sys.version_info[0] < 4:
|
|
|
|
supported_string_types += ['u']
|
|
|
|
# test unicode
|
|
|
|
assertCollectSuccess('u', u'a')
|
|
|
|
if sys.version_info[0] < 3:
|
|
|
|
supported_string_types += ['c']
|
|
|
|
# test string
|
|
|
|
assertCollectSuccess('c', 'a')
|
|
|
|
|
|
|
|
# supported float and double
|
|
|
|
#
|
|
|
|
# Test max, min, and precision for float and double, assuming IEEE 754
|
|
|
|
# floating-point format.
|
|
|
|
supported_fractional_types = ['f', 'd']
|
|
|
|
assertCollectSuccess('f', ctypes.c_float(1e+38).value)
|
|
|
|
assertCollectSuccess('f', ctypes.c_float(1e-38).value)
|
|
|
|
assertCollectSuccess('f', ctypes.c_float(1.123456).value)
|
|
|
|
assertCollectSuccess('d', sys.float_info.max)
|
|
|
|
assertCollectSuccess('d', sys.float_info.min)
|
|
|
|
assertCollectSuccess('d', sys.float_info.epsilon)
|
|
|
|
|
|
|
|
# supported signed int types
|
|
|
|
#
|
|
|
|
# The size of C types changes with implementation, so we need to make sure
|
|
|
|
# that there is no overflow error on the platform running this test.
|
|
|
|
supported_signed_int_types = list(
|
|
|
|
set(_array_signed_int_typecode_ctype_mappings.keys())
|
|
|
|
.intersection(set(_array_type_mappings.keys())))
|
|
|
|
for t in supported_signed_int_types:
|
|
|
|
ctype = _array_signed_int_typecode_ctype_mappings[t]
|
|
|
|
max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1)
|
|
|
|
assertCollectSuccess(t, max_val - 1)
|
|
|
|
assertCollectSuccess(t, -max_val)
|
|
|
|
|
|
|
|
# supported unsigned int types
|
|
|
|
#
|
|
|
|
# JVM does not have unsigned types. We need to be very careful to make
|
|
|
|
# sure that there is no overflow error.
|
|
|
|
supported_unsigned_int_types = list(
|
|
|
|
set(_array_unsigned_int_typecode_ctype_mappings.keys())
|
|
|
|
.intersection(set(_array_type_mappings.keys())))
|
|
|
|
for t in supported_unsigned_int_types:
|
|
|
|
ctype = _array_unsigned_int_typecode_ctype_mappings[t]
|
|
|
|
assertCollectSuccess(t, 2 ** (ctypes.sizeof(ctype) * 8) - 1)
|
|
|
|
|
|
|
|
# all supported types
|
|
|
|
#
|
|
|
|
# Make sure the types tested above:
|
|
|
|
# 1. are all supported types
|
|
|
|
# 2. cover all supported types
|
|
|
|
supported_types = (supported_string_types +
|
|
|
|
supported_fractional_types +
|
|
|
|
supported_signed_int_types +
|
|
|
|
supported_unsigned_int_types)
|
|
|
|
self.assertEqual(set(supported_types), set(_array_type_mappings.keys()))
|
|
|
|
|
|
|
|
# all unsupported types
|
|
|
|
#
|
|
|
|
# The keys in _array_type_mappings form a complete list of all supported types,
|
|
|
|
# and types not in _array_type_mappings are considered unsupported.
|
|
|
|
# `array.typecodes` are not supported in python 2.
|
|
|
|
if sys.version_info[0] < 3:
|
|
|
|
all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd'])
|
|
|
|
else:
|
|
|
|
all_types = set(array.typecodes)
|
|
|
|
unsupported_types = all_types - set(supported_types)
|
|
|
|
# test unsupported types
|
|
|
|
for t in unsupported_types:
|
|
|
|
with self.assertRaises(TypeError):
|
|
|
|
a = array.array(t)
|
|
|
|
self.spark.createDataFrame([Row(myarray=a)]).collect()
|
|
|
|
|
2017-05-07 22:58:27 -04:00
|
|
|
def test_bucketed_write(self):
|
|
|
|
data = [
|
|
|
|
(1, "foo", 3.0), (2, "foo", 5.0),
|
|
|
|
(3, "bar", -1.0), (4, "bar", 6.0),
|
|
|
|
]
|
|
|
|
df = self.spark.createDataFrame(data, ["x", "y", "z"])
|
|
|
|
|
|
|
|
def count_bucketed_cols(names, table="pyspark_bucket"):
|
|
|
|
"""Given a sequence of column names and a table name
|
|
|
|
query the catalog and return the number of columns which are
|
|
|
|
used for bucketing
|
|
|
|
"""
|
|
|
|
cols = self.spark.catalog.listColumns(table)
|
|
|
|
num = len([c for c in cols if c.name in names and c.isBucket])
|
|
|
|
return num
|
|
|
|
|
|
|
|
# Test write with one bucketing column
|
|
|
|
df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
|
|
|
|
self.assertEqual(count_bucketed_cols(["x"]), 1)
|
|
|
|
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
|
|
|
|
|
|
|
|
# Test write two bucketing columns
|
|
|
|
df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
|
|
|
|
self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
|
|
|
|
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
|
|
|
|
|
|
|
|
# Test write with bucket and sort
|
|
|
|
df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
|
|
|
|
self.assertEqual(count_bucketed_cols(["x"]), 1)
|
|
|
|
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
|
|
|
|
|
|
|
|
# Test write with a list of columns
|
|
|
|
df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
|
|
|
|
self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
|
|
|
|
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
|
|
|
|
|
|
|
|
# Test write with bucket and sort with a list of columns
|
|
|
|
(df.write.bucketBy(2, "x")
|
|
|
|
.sortBy(["y", "z"])
|
|
|
|
.mode("overwrite").saveAsTable("pyspark_bucket"))
|
|
|
|
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
|
|
|
|
|
|
|
|
# Test write with bucket and sort with multiple columns
|
|
|
|
(df.write.bucketBy(2, "x")
|
|
|
|
.sortBy("y", "z")
|
|
|
|
.mode("overwrite").saveAsTable("pyspark_bucket"))
|
|
|
|
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
|
|
|
|
|
2017-11-28 03:45:22 -05:00
|
|
|
def _to_pandas(self):
|
|
|
|
from datetime import datetime, date
|
2017-06-22 04:22:02 -04:00
|
|
|
schema = StructType().add("a", IntegerType()).add("b", StringType())\
|
2017-11-28 03:45:22 -05:00
|
|
|
.add("c", BooleanType()).add("d", FloatType())\
|
|
|
|
.add("dt", DateType()).add("ts", TimestampType())
|
2017-06-22 04:22:02 -04:00
|
|
|
data = [
|
2017-11-28 03:45:22 -05:00
|
|
|
(1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
|
|
|
|
(2, "foo", True, 5.0, None, None),
|
|
|
|
(3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
|
|
|
|
(4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
|
2017-06-22 04:22:02 -04:00
|
|
|
]
|
|
|
|
df = self.spark.createDataFrame(data, schema)
|
2017-11-28 03:45:22 -05:00
|
|
|
return df.toPandas()
|
|
|
|
|
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
## What changes were proposed in this pull request?
This PR proposes to explicitly specify Pandas and PyArrow versions in PySpark tests to skip or test.
We declared the extra dependencies:
https://github.com/apache/spark/blob/b8bfce51abf28c66ba1fc67b0f25fe1617c81025/python/setup.py#L204
In case of PyArrow:
Currently we only check whether pyarrow is installed, without checking its version, and tests already fail as a result. For example, if PyArrow 0.7.0 is installed:
```
======================================================================
ERROR: test_vectorized_udf_wrong_return_type (pyspark.sql.tests.ScalarPandasUDF)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/.../spark/python/pyspark/sql/tests.py", line 4019, in test_vectorized_udf_wrong_return_type
f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
File "/.../spark/python/pyspark/sql/functions.py", line 2309, in pandas_udf
return _create_udf(f=f, returnType=return_type, evalType=eval_type)
File "/.../spark/python/pyspark/sql/udf.py", line 47, in _create_udf
require_minimum_pyarrow_version()
File "/.../spark/python/pyspark/sql/utils.py", line 132, in require_minimum_pyarrow_version
"however, your version was %s." % pyarrow.__version__)
ImportError: pyarrow >= 0.8.0 must be installed on calling Python process; however, your version was 0.7.0.
----------------------------------------------------------------------
Ran 33 tests in 8.098s
FAILED (errors=33)
```
In case of Pandas:
There are a few tests for old Pandas that previously ran only when the Pandas version was lower; I rewrote them so they are exercised both when the Pandas version is too low and when Pandas is missing.
## How was this patch tested?
Manually tested by modifying the condition:
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
```
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
```
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
```
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
```
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20487 from HyukjinKwon/pyarrow-pandas-skip.
2018-02-07 09:28:10 -05:00
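As a rough sketch of the skip/ImportError pattern described above (assumed helper name and minimum version; the real checks live in pyspark.sql.utils and differ in detail):
```python
# Minimal sketch only -- not the actual pyspark implementation.
from distutils.version import LooseVersion

def require_minimum_pandas_version(minimum="0.19.2"):
    try:
        import pandas
    except ImportError:
        raise ImportError(
            "Pandas >= %s must be installed; however, it was not found." % minimum)
    if LooseVersion(pandas.__version__) < LooseVersion(minimum):
        raise ImportError(
            "Pandas >= %s must be installed; however, your version was %s."
            % (minimum, pandas.__version__))
```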
|
|
|
@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
|
2017-11-28 03:45:22 -05:00
|
|
|
def test_to_pandas(self):
|
|
|
|
import numpy as np
|
|
|
|
pdf = self._to_pandas()
|
|
|
|
types = pdf.dtypes
|
2017-06-22 04:22:02 -04:00
|
|
|
self.assertEquals(types[0], np.int32)
|
|
|
|
self.assertEquals(types[1], np.object)
|
|
|
|
self.assertEquals(types[2], np.bool)
|
|
|
|
self.assertEquals(types[3], np.float32)
|
2018-02-06 01:52:25 -05:00
|
|
|
self.assertEquals(types[4], np.object) # datetime.date
|
2017-11-28 03:45:22 -05:00
|
|
|
self.assertEquals(types[5], 'datetime64[ns]')
|
|
|
|
|
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20487 from HyukjinKwon/pyarrow-pandas-skip.
2018-02-07 09:28:10 -05:00
|
|
|
@unittest.skipIf(_have_pandas, "Required Pandas was found.")
|
|
|
|
def test_to_pandas_required_pandas_not_found(self):
|
2017-11-28 03:45:22 -05:00
|
|
|
with QuietTest(self.sc):
|
2017-12-22 06:09:51 -05:00
|
|
|
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
|
2017-11-28 03:45:22 -05:00
|
|
|
self._to_pandas()
|
2017-06-22 04:22:02 -04:00
|
|
|
|
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20487 from HyukjinKwon/pyarrow-pandas-skip.
2018-02-07 09:28:10 -05:00
|
|
|
@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
|
2017-09-22 09:39:47 -04:00
|
|
|
def test_to_pandas_avoid_astype(self):
|
|
|
|
import numpy as np
|
|
|
|
schema = StructType().add("a", IntegerType()).add("b", StringType())\
|
|
|
|
.add("c", IntegerType())
|
|
|
|
data = [(1, "foo", 16777220), (None, "bar", None)]
|
|
|
|
df = self.spark.createDataFrame(data, schema)
|
|
|
|
types = df.toPandas().dtypes
|
|
|
|
self.assertEquals(types[0], np.float64) # doesn't convert to np.int32 due to NaN value.
|
|
|
|
self.assertEquals(types[1], np.object)
|
|
|
|
self.assertEquals(types[2], np.float64)
|
|
|
|
|
2017-07-07 01:05:22 -04:00
|
|
|
def test_create_dataframe_from_array_of_long(self):
|
|
|
|
import array
|
|
|
|
data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))]
|
|
|
|
df = self.spark.createDataFrame(data)
|
|
|
|
self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))
|
|
|
|
|
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20487 from HyukjinKwon/pyarrow-pandas-skip.
2018-02-07 09:28:10 -05:00
|
|
|
@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
|
2017-11-07 15:32:37 -05:00
|
|
|
def test_create_dataframe_from_pandas_with_timestamp(self):
|
|
|
|
import pandas as pd
|
|
|
|
from datetime import datetime
|
|
|
|
pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
|
|
|
|
"d": [pd.Timestamp.now().date()]})
|
|
|
|
# test types are inferred correctly without specifying schema
|
|
|
|
df = self.spark.createDataFrame(pdf)
|
|
|
|
self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
|
|
|
|
self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
|
|
|
|
# test with schema will accept pdf as input
|
|
|
|
df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
|
|
|
|
self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
|
|
|
|
self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
|
|
|
|
|
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20487 from HyukjinKwon/pyarrow-pandas-skip.
2018-02-07 09:28:10 -05:00
|
|
|
@unittest.skipIf(_have_pandas, "Required Pandas was found.")
|
|
|
|
def test_create_dataframe_required_pandas_not_found(self):
|
2017-11-28 03:45:22 -05:00
|
|
|
with QuietTest(self.sc):
|
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20487 from HyukjinKwon/pyarrow-pandas-skip.
2018-02-07 09:28:10 -05:00
|
|
|
with self.assertRaisesRegexp(
|
|
|
|
ImportError,
|
2018-02-07 22:46:10 -05:00
|
|
|
"(Pandas >= .* must be installed|No module named '?pandas'?)"):
|
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20487 from HyukjinKwon/pyarrow-pandas-skip.
2018-02-07 09:28:10 -05:00
|
|
|
import pandas as pd
|
|
|
|
from datetime import datetime
|
|
|
|
pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
|
|
|
|
"d": [pd.Timestamp.now().date()]})
|
2017-11-28 03:45:22 -05:00
|
|
|
self.spark.createDataFrame(pdf)
|
|
|
|
|
2018-02-10 11:08:02 -05:00
|
|
|
# Regression test for SPARK-23360
|
|
|
|
@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
|
|
|
|
def test_create_dataframe_from_pandas_with_dst(self):
|
|
|
|
import pandas as pd
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
|
|
pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})
|
|
|
|
|
|
|
|
df = self.spark.createDataFrame(pdf)
|
|
|
|
self.assertPandasEqual(pdf, df.toPandas())
|
|
|
|
|
|
|
|
orig_env_tz = os.environ.get('TZ', None)
|
|
|
|
orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone')
|
|
|
|
try:
|
|
|
|
tz = 'America/Los_Angeles'
|
|
|
|
os.environ['TZ'] = tz
|
|
|
|
time.tzset()
|
|
|
|
self.spark.conf.set('spark.sql.session.timeZone', tz)
|
|
|
|
|
|
|
|
df = self.spark.createDataFrame(pdf)
|
|
|
|
self.assertPandasEqual(pdf, df.toPandas())
|
|
|
|
finally:
|
|
|
|
del os.environ['TZ']
|
|
|
|
if orig_env_tz is not None:
|
|
|
|
os.environ['TZ'] = orig_env_tz
|
|
|
|
time.tzset()
|
|
|
|
self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz)
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2016-06-28 10:54:44 -04:00
|
|
|
class HiveSparkSubmitTests(SparkSubmitTests):
|
|
|
|
|
|
|
|
def test_hivecontext(self):
|
|
|
|
# This test checks that HiveContext is using Hive metastore (SPARK-16224).
|
|
|
|
# It sets a metastore url and checks if there is a derby dir created by
|
|
|
|
# Hive metastore. If this derby dir exists, HiveContext is using
|
|
|
|
# Hive metastore.
|
|
|
|
metastore_path = os.path.join(tempfile.mkdtemp(), "spark16224_metastore_db")
|
|
|
|
metastore_URL = "jdbc:derby:;databaseName=" + metastore_path + ";create=true"
|
|
|
|
hive_site_dir = os.path.join(self.programDir, "conf")
|
|
|
|
hive_site_file = self.createTempFile("hive-site.xml", ("""
|
|
|
|
|<configuration>
|
|
|
|
| <property>
|
|
|
|
| <name>javax.jdo.option.ConnectionURL</name>
|
|
|
|
| <value>%s</value>
|
|
|
|
| </property>
|
|
|
|
|</configuration>
|
|
|
|
""" % metastore_URL).lstrip(), "conf")
|
|
|
|
script = self.createTempFile("test.py", """
|
|
|
|
|import os
|
|
|
|
|
|
|
|
|
|from pyspark.conf import SparkConf
|
|
|
|
|from pyspark.context import SparkContext
|
|
|
|
|from pyspark.sql import HiveContext
|
|
|
|
|
|
|
|
|
|conf = SparkConf()
|
|
|
|
|sc = SparkContext(conf=conf)
|
|
|
|
|hive_context = HiveContext(sc)
|
|
|
|
|print(hive_context.sql("show databases").collect())
|
|
|
|
""")
|
|
|
|
proc = subprocess.Popen(
|
|
|
|
[self.sparkSubmit, "--master", "local-cluster[1,1,1024]",
|
|
|
|
"--driver-class-path", hive_site_dir, script],
|
|
|
|
stdout=subprocess.PIPE)
|
|
|
|
out, err = proc.communicate()
|
|
|
|
self.assertEqual(0, proc.returncode)
|
|
|
|
self.assertIn("default", out.decode('utf-8'))
|
|
|
|
self.assertTrue(os.path.exists(metastore_path))
|
|
|
|
|
|
|
|
|
2017-10-29 22:50:22 -04:00
|
|
|
class SQLTests2(ReusedSQLTestCase):
|
2017-01-12 07:53:31 -05:00
|
|
|
|
|
|
|
# We can't include this test into SQLTests because we will stop class's SparkContext and cause
|
|
|
|
# other tests failed.
|
|
|
|
def test_sparksession_with_stopped_sparkcontext(self):
|
|
|
|
self.sc.stop()
|
|
|
|
sc = SparkContext('local[4]', self.sc.appName)
|
|
|
|
spark = SparkSession.builder.getOrCreate()
|
2017-09-08 01:26:07 -04:00
|
|
|
try:
|
|
|
|
df = spark.createDataFrame([(1, 2)], ["c", "c"])
|
|
|
|
df.collect()
|
|
|
|
finally:
|
|
|
|
spark.stop()
|
|
|
|
sc.stop()
|
2017-01-12 07:53:31 -05:00
|
|
|
|
|
|
|
|
2018-01-31 06:04:51 -05:00
|
|
|
class SparkSessionTests(PySparkTestCase):
|
|
|
|
|
|
|
|
# This test is separate because it's closely related to the session's start and stop.
|
|
|
|
# See SPARK-23228.
|
|
|
|
def test_set_jvm_default_session(self):
|
|
|
|
spark = SparkSession.builder.getOrCreate()
|
|
|
|
try:
|
|
|
|
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
|
|
|
|
finally:
|
|
|
|
spark.stop()
|
|
|
|
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty())
|
|
|
|
|
|
|
|
def test_jvm_default_session_already_set(self):
|
|
|
|
# Here, we assume the default session is already set in the JVM.
|
|
|
|
jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc())
|
|
|
|
self.sc._jvm.SparkSession.setDefaultSession(jsession)
|
|
|
|
|
|
|
|
spark = SparkSession.builder.getOrCreate()
|
|
|
|
try:
|
|
|
|
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
|
|
|
|
# The session should be the same as the existing one.
|
|
|
|
self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get()))
|
|
|
|
finally:
|
|
|
|
spark.stop()
|
|
|
|
|
|
|
|
|
2017-01-31 21:03:39 -05:00
|
|
|
class UDFInitializationTests(unittest.TestCase):
|
|
|
|
def tearDown(self):
|
|
|
|
if SparkSession._instantiatedSession is not None:
|
|
|
|
SparkSession._instantiatedSession.stop()
|
|
|
|
|
|
|
|
if SparkContext._active_spark_context is not None:
|
|
|
|
SparkContext._active_spark_context.stop()
|
|
|
|
|
|
|
|
def test_udf_init_shouldnt_initialize_context(self):
|
|
|
|
from pyspark.sql.functions import UserDefinedFunction
|
|
|
|
|
|
|
|
UserDefinedFunction(lambda x: x, StringType())
|
|
|
|
|
|
|
|
self.assertIsNone(
|
|
|
|
SparkContext._active_spark_context,
|
|
|
|
"SparkContext shouldn't be initialized when UserDefinedFunction is created."
|
|
|
|
)
|
|
|
|
self.assertIsNone(
|
|
|
|
SparkSession._instantiatedSession,
|
|
|
|
"SparkSession shouldn't be initialized when UserDefinedFunction is created."
|
|
|
|
)
|
|
|
|
|
|
|
|
|
2015-02-10 20:29:52 -05:00
|
|
|
class HiveContextSQLTests(ReusedPySparkTestCase):
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def setUpClass(cls):
|
|
|
|
ReusedPySparkTestCase.setUpClass()
|
|
|
|
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
|
2015-02-17 18:44:37 -05:00
|
|
|
try:
|
|
|
|
cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
|
|
|
|
except py4j.protocol.Py4JError:
|
2015-06-03 03:23:34 -04:00
|
|
|
cls.tearDownClass()
|
2015-05-23 11:30:05 -04:00
|
|
|
raise unittest.SkipTest("Hive is not available")
|
2015-04-16 19:20:57 -04:00
|
|
|
except TypeError:
|
2015-06-03 03:23:34 -04:00
|
|
|
cls.tearDownClass()
|
2015-05-23 11:30:05 -04:00
|
|
|
raise unittest.SkipTest("Hive is not available")
|
2015-02-10 20:29:52 -05:00
|
|
|
os.unlink(cls.tempdir.name)
|
2016-05-11 14:24:16 -04:00
|
|
|
cls.spark = HiveContext._createForTesting(cls.sc)
|
2015-02-10 20:29:52 -05:00
|
|
|
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
|
2015-02-14 02:03:22 -05:00
|
|
|
cls.df = cls.sc.parallelize(cls.testData).toDF()
|
2015-02-10 20:29:52 -05:00
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def tearDownClass(cls):
|
|
|
|
ReusedPySparkTestCase.tearDownClass()
|
|
|
|
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
|
|
|
|
|
|
|
|
def test_save_and_load_table(self):
|
|
|
|
df = self.df
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json")
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()),
|
2016-05-11 14:24:16 -04:00
|
|
|
sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()),
|
2016-05-11 14:24:16 -04:00
|
|
|
sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("DROP TABLE externalJsonTable")
|
2015-02-10 20:29:52 -05:00
|
|
|
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
|
2015-02-10 20:29:52 -05:00
|
|
|
schema = StructType([StructField("value", StringType(), True)])
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.createExternalTable("externalJsonTable", source="json",
|
|
|
|
schema=schema, path=tmpPath,
|
|
|
|
noUse="this option will not be used")
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()),
|
2016-05-11 14:24:16 -04:00
|
|
|
sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.select("value").collect()),
|
2016-05-11 14:24:16 -04:00
|
|
|
sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("DROP TABLE savedJsonTable")
|
|
|
|
self.spark.sql("DROP TABLE externalJsonTable")
|
2015-02-10 20:29:52 -05:00
|
|
|
|
2016-05-11 14:24:16 -04:00
|
|
|
defaultDataSourceName = self.spark.getConf("spark.sql.sources.default",
|
|
|
|
"org.apache.spark.sql.parquet")
|
|
|
|
self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()),
|
2016-05-11 14:24:16 -04:00
|
|
|
sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()),
|
2016-05-11 14:24:16 -04:00
|
|
|
sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("DROP TABLE savedJsonTable")
|
|
|
|
self.spark.sql("DROP TABLE externalJsonTable")
|
|
|
|
self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
|
2015-02-10 20:29:52 -05:00
|
|
|
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
2015-05-23 11:30:05 -04:00
|
|
|
def test_window_functions(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
|
2015-05-23 11:30:05 -04:00
|
|
|
w = Window.partitionBy("value").orderBy("key")
|
|
|
|
from pyspark.sql import functions as F
|
|
|
|
sel = df.select(df.value, df.key,
|
|
|
|
F.max("key").over(w.rowsBetween(0, 1)),
|
|
|
|
F.min("key").over(w.rowsBetween(0, 1)),
|
|
|
|
F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
|
2016-01-04 21:02:38 -05:00
|
|
|
F.row_number().over(w),
|
2015-05-23 11:30:05 -04:00
|
|
|
F.rank().over(w),
|
2016-01-04 21:02:38 -05:00
|
|
|
F.dense_rank().over(w),
|
2015-05-23 11:30:05 -04:00
|
|
|
F.ntile(2).over(w))
|
|
|
|
rs = sorted(sel.collect())
|
|
|
|
expected = [
|
|
|
|
("1", 1, 1, 1, 1, 1, 1, 1, 1),
|
|
|
|
("2", 1, 1, 1, 3, 1, 1, 1, 1),
|
|
|
|
("2", 1, 2, 1, 3, 2, 1, 1, 1),
|
|
|
|
("2", 2, 2, 2, 3, 3, 3, 2, 2)
|
|
|
|
]
|
|
|
|
for r, ex in zip(rs, expected):
|
|
|
|
self.assertEqual(tuple(r), ex[:len(r)])
|
2015-05-19 17:23:28 -04:00
|
|
|
|
2015-08-14 16:55:29 -04:00
|
|
|
def test_window_functions_without_partitionBy(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
|
2015-08-14 16:55:29 -04:00
|
|
|
w = Window.orderBy("key", df.value)
|
|
|
|
from pyspark.sql import functions as F
|
|
|
|
sel = df.select(df.value, df.key,
|
|
|
|
F.max("key").over(w.rowsBetween(0, 1)),
|
|
|
|
F.min("key").over(w.rowsBetween(0, 1)),
|
|
|
|
F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
|
2016-01-04 21:02:38 -05:00
|
|
|
F.row_number().over(w),
|
2015-08-14 16:55:29 -04:00
|
|
|
F.rank().over(w),
|
2016-01-04 21:02:38 -05:00
|
|
|
F.dense_rank().over(w),
|
2015-08-14 16:55:29 -04:00
|
|
|
F.ntile(2).over(w))
|
|
|
|
rs = sorted(sel.collect())
|
|
|
|
expected = [
|
|
|
|
("1", 1, 1, 1, 4, 1, 1, 1, 1),
|
|
|
|
("2", 1, 1, 1, 4, 2, 2, 2, 1),
|
|
|
|
("2", 1, 2, 1, 4, 3, 2, 2, 2),
|
|
|
|
("2", 2, 2, 2, 4, 4, 4, 3, 2)
|
|
|
|
]
|
|
|
|
for r, ex in zip(rs, expected):
|
|
|
|
self.assertEqual(tuple(r), ex[:len(r)])
|
|
|
|
|
2016-10-11 01:33:20 -04:00
|
|
|
def test_window_functions_cumulative_sum(self):
|
|
|
|
df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"])
|
|
|
|
from pyspark.sql import functions as F
|
[SPARK-17845] [SQL] More self-evident window function frame boundary API
## What changes were proposed in this pull request?
This patch improves the window function frame boundary API to make it more obvious to read and to use. The two high level changes are:
1. Create Window.currentRow, Window.unboundedPreceding, Window.unboundedFollowing to indicate the special values in frame boundaries. These methods map to the special integral values so we are not breaking backward compatibility here. This change makes the frame boundaries more self-evident (instead of Long.MinValue, it becomes Window.unboundedPreceding).
2. In Python, for any value less than or equal to JVM's Long.MinValue, treat it as Window.unboundedPreceding. For any value larger than or equal to JVM's Long.MaxValue, treat it as Window.unboundedFollowing. Before this change, if the user specifies any value that is less than Long.MinValue but not -sys.maxsize (e.g. -sys.maxsize + 1), the number we pass over to the JVM would overflow, resulting in a frame that does not make sense.
Code example required to specify a frame before this patch:
```
Window.rowsBetween(-Long.MinValue, 0)
```
While the above code should still work, the new way is more obvious to read:
```
Window.rowsBetween(Window.unboundedPreceding, Window.currentRow)
```
## How was this patch tested?
- Updated DataFrameWindowSuite (for Scala/Java)
- Updated test_window_functions_cumulative_sum (for Python)
- Renamed DataFrameWindowSuite to DataFrameWindowFunctionsSuite to better reflect its purpose
Author: Reynold Xin <rxin@databricks.com>
Closes #15438 from rxin/SPARK-17845.
2016-10-12 19:45:10 -04:00
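A minimal PySpark sketch of the boundary constants described above (not part of the test file; the session and sample data are illustrative assumptions):
```python
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["key", "value"])

# Running sum from the start of the partition up to the current row, written
# with the self-evident constants instead of raw sentinel integers.
w = (Window.partitionBy("key").orderBy("value")
     .rowsBetween(Window.unboundedPreceding, Window.currentRow))
df.select("key", "value", F.sum("value").over(w).alias("running_sum")).show()
```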
|
|
|
|
|
|
|
# Test cumulative sum
|
|
|
|
sel = df.select(
|
|
|
|
df.key,
|
|
|
|
F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0)))
|
|
|
|
rs = sorted(sel.collect())
|
|
|
|
expected = [("one", 1), ("two", 3)]
|
|
|
|
for r, ex in zip(rs, expected):
|
|
|
|
self.assertEqual(tuple(r), ex[:len(r)])
|
|
|
|
|
|
|
|
# Test boundary values less than JVM's Long.MinValue and make sure we don't overflow
|
|
|
|
sel = df.select(
|
|
|
|
df.key,
|
|
|
|
F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0)))
|
2016-10-11 01:33:20 -04:00
|
|
|
rs = sorted(sel.collect())
|
|
|
|
expected = [("one", 1), ("two", 3)]
|
|
|
|
for r, ex in zip(rs, expected):
|
|
|
|
self.assertEqual(tuple(r), ex[:len(r)])
|
|
|
|
|
2016-10-12 19:45:10 -04:00
|
|
|
# Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow
|
|
|
|
frame_end = Window.unboundedFollowing + 1
|
|
|
|
sel = df.select(
|
|
|
|
df.key,
|
|
|
|
F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end)))
|
|
|
|
rs = sorted(sel.collect())
|
|
|
|
expected = [("one", 3), ("two", 2)]
|
|
|
|
for r, ex in zip(rs, expected):
|
|
|
|
self.assertEqual(tuple(r), ex[:len(r)])
|
|
|
|
|
2015-11-09 17:30:37 -05:00
|
|
|
def test_collect_functions(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
|
2015-11-09 17:30:37 -05:00
|
|
|
from pyspark.sql import functions
|
|
|
|
|
|
|
|
self.assertEqual(
|
|
|
|
sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r),
|
|
|
|
[1, 2])
|
|
|
|
self.assertEqual(
|
|
|
|
sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r),
|
|
|
|
[1, 1, 1, 2])
|
|
|
|
self.assertEqual(
|
|
|
|
sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r),
|
|
|
|
["1", "2"])
|
|
|
|
self.assertEqual(
|
|
|
|
sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
|
|
|
|
["1", "2", "2", "2"])
|
|
|
|
|
[SPARK-17514] df.take(1) and df.limit(1).collect() should perform the same in Python
## What changes were proposed in this pull request?
In PySpark, `df.take(1)` runs a single-stage job which computes only one partition of the DataFrame, while `df.limit(1).collect()` computes all partitions and runs a two-stage job. This difference in performance is confusing.
The reason why `limit(1).collect()` is so much slower is that `collect()` internally maps to `df.rdd.<some-pyspark-conversions>.toLocalIterator`, which causes Spark SQL to build a query where a global limit appears in the middle of the plan; this, in turn, ends up being executed inefficiently because limits in the middle of plans are now implemented by repartitioning to a single task rather than by running a `take()` job on the driver (this was done in #7334, a patch which was a prerequisite to allowing partition-local limits to be pushed beneath unions, etc.).
In order to fix this performance problem I think that we should generalize the fix from SPARK-10731 / #8876 so that `DataFrame.collect()` also delegates to the Scala implementation and shares the same performance properties. This patch modifies `DataFrame.collect()` to first collect all results to the driver and then pass them to Python, allowing this query to be planned using Spark's `CollectLimit` optimizations.
## How was this patch tested?
Added a regression test in `sql/tests.py` which asserts that the expected number of jobs, stages, and tasks are run for both queries.
Author: Josh Rosen <joshrosen@databricks.com>
Closes #15068 from JoshRosen/pyspark-collect-limit.
2016-09-14 13:10:01 -04:00
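A hedged illustration of the user-facing effect (the session and DataFrame below are assumptions for the example, not taken from the test):
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(0, 1000, numPartitions=10)

# After SPARK-17514 both calls should be planned with CollectLimit and run a
# single job with a single task, instead of scanning every partition.
first_via_take = df.take(1)
first_via_limit = df.limit(1).collect()
```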
|
|
|
def test_limit_and_take(self):
|
|
|
|
df = self.spark.range(1, 1000, numPartitions=10)
|
|
|
|
|
|
|
|
def assert_runs_only_one_job_stage_and_task(job_group_name, f):
|
|
|
|
tracker = self.sc.statusTracker()
|
|
|
|
self.sc.setJobGroup(job_group_name, description="")
|
|
|
|
f()
|
|
|
|
jobs = tracker.getJobIdsForGroup(job_group_name)
|
|
|
|
self.assertEqual(1, len(jobs))
|
|
|
|
stages = tracker.getJobInfo(jobs[0]).stageIds
|
|
|
|
self.assertEqual(1, len(stages))
|
|
|
|
self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks)
|
|
|
|
|
|
|
|
# Regression test for SPARK-10731: take should delegate to Scala implementation
|
|
|
|
assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1))
|
|
|
|
# Regression test for SPARK-17514: limit(n).collect() should perform the same as take(n)
|
|
|
|
assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect())
|
|
|
|
|
2017-02-07 09:50:30 -05:00
|
|
|
def test_datetime_functions(self):
|
|
|
|
from pyspark.sql import functions
|
|
|
|
from datetime import date, datetime
|
|
|
|
df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
|
|
|
|
parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
|
[SPARK-20639][SQL] Add single argument support for to_timestamp in SQL with documentation improvement
## What changes were proposed in this pull request?
This PR proposes three things as below:
- Use casting rules to a timestamp in `to_timestamp` by default (it was `yyyy-MM-dd HH:mm:ss`).
- Support single argument for `to_timestamp` similarly with APIs in other languages.
For example, the one below works
```
import org.apache.spark.sql.functions._
Seq("2016-12-31 00:12:00.00").toDF("a").select(to_timestamp(col("a"))).show()
```
prints
```
+----------------------------------------+
|to_timestamp(`a`, 'yyyy-MM-dd HH:mm:ss')|
+----------------------------------------+
| 2016-12-31 00:12:00|
+----------------------------------------+
```
whereas this does not work in SQL.
**Before**
```
spark-sql> SELECT to_timestamp('2016-12-31 00:12:00');
Error in query: Invalid number of arguments for function to_timestamp; line 1 pos 7
```
**After**
```
spark-sql> SELECT to_timestamp('2016-12-31 00:12:00');
2016-12-31 00:12:00
```
- Related document improvement for SQL function descriptions and other API descriptions accordingly.
**Before**
```
spark-sql> DESCRIBE FUNCTION extended to_date;
...
Usage: to_date(date_str, fmt) - Parses the `left` expression with the `fmt` expression. Returns null with invalid input.
Extended Usage:
Examples:
> SELECT to_date('2016-12-31', 'yyyy-MM-dd');
2016-12-31
```
```
spark-sql> DESCRIBE FUNCTION extended to_timestamp;
...
Usage: to_timestamp(timestamp, fmt) - Parses the `left` expression with the `format` expression to a timestamp. Returns null with invalid input.
Extended Usage:
Examples:
> SELECT to_timestamp('2016-12-31', 'yyyy-MM-dd');
2016-12-31 00:00:00.0
```
**After**
```
spark-sql> DESCRIBE FUNCTION extended to_date;
...
Usage:
to_date(date_str[, fmt]) - Parses the `date_str` expression with the `fmt` expression to
a date. Returns null with invalid input. By default, it follows casting rules to a date if
the `fmt` is omitted.
Extended Usage:
Examples:
> SELECT to_date('2009-07-30 04:17:52');
2009-07-30
> SELECT to_date('2016-12-31', 'yyyy-MM-dd');
2016-12-31
```
```
spark-sql> DESCRIBE FUNCTION extended to_timestamp;
...
Usage:
to_timestamp(timestamp[, fmt]) - Parses the `timestamp` expression with the `fmt` expression to
a timestamp. Returns null with invalid input. By default, it follows casting rules to
a timestamp if the `fmt` is omitted.
Extended Usage:
Examples:
> SELECT to_timestamp('2016-12-31 00:12:00');
2016-12-31 00:12:00
> SELECT to_timestamp('2016-12-31', 'yyyy-MM-dd');
2016-12-31 00:00:00
```
## How was this patch tested?
Added tests in `datetime.sql`.
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #17901 from HyukjinKwon/to_timestamp_arg.
2017-05-12 04:42:58 -04:00
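The same single-argument form is available from the Python API; a small sketch (the input literal is only an example):
```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("2016-12-31 00:12:00",)], ["a"])

# With no format argument, to_timestamp falls back to the casting rules.
df.select(F.to_timestamp(df["a"])).show()
# An explicit format is still accepted.
df.select(F.to_timestamp(df["a"], "yyyy-MM-dd HH:mm:ss")).show()
```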
|
|
|
self.assertEquals(date(2017, 1, 22), parse_result['to_date(`dateCol`)'])
|
2017-02-07 09:50:30 -05:00
|
|
|
|
2016-12-02 20:39:28 -05:00
|
|
|
@unittest.skipIf(sys.version_info < (3, 3), "Python < 3.3 doesn't support unittest.mock")
|
|
|
|
def test_unbounded_frames(self):
|
|
|
|
from unittest.mock import patch
|
|
|
|
from pyspark.sql import functions as F
|
|
|
|
from pyspark.sql import window
|
|
|
|
import importlib
|
|
|
|
|
|
|
|
df = self.spark.range(0, 3)
|
|
|
|
|
|
|
|
def rows_frame_match():
|
|
|
|
return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select(
|
|
|
|
F.count("*").over(window.Window.rowsBetween(-sys.maxsize, sys.maxsize))
|
|
|
|
).columns[0]
|
|
|
|
|
|
|
|
def range_frame_match():
|
|
|
|
return "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select(
|
|
|
|
F.count("*").over(window.Window.rangeBetween(-sys.maxsize, sys.maxsize))
|
|
|
|
).columns[0]
|
|
|
|
|
|
|
|
with patch("sys.maxsize", 2 ** 31 - 1):
|
|
|
|
importlib.reload(window)
|
|
|
|
self.assertTrue(rows_frame_match())
|
|
|
|
self.assertTrue(range_frame_match())
|
|
|
|
|
|
|
|
with patch("sys.maxsize", 2 ** 63 - 1):
|
|
|
|
importlib.reload(window)
|
|
|
|
self.assertTrue(rows_frame_match())
|
|
|
|
self.assertTrue(range_frame_match())
|
|
|
|
|
|
|
|
with patch("sys.maxsize", 2 ** 127 - 1):
|
|
|
|
importlib.reload(window)
|
|
|
|
self.assertTrue(rows_frame_match())
|
|
|
|
self.assertTrue(range_frame_match())
|
|
|
|
|
|
|
|
importlib.reload(window)
|
2015-08-14 16:55:29 -04:00
|
|
|
|
[SPARK-19507][SPARK-21296][PYTHON] Avoid per-record type dispatch in schema verification and improve exception message
## What changes were proposed in this pull request?
**Context**
While reviewing https://github.com/apache/spark/pull/17227, I realised here we type-dispatch per record. The PR itself is fine in terms of performance as is but this prints a prefix, `"obj"` in exception message as below:
```
from pyspark.sql.types import *
schema = StructType([StructField('s', IntegerType(), nullable=False)])
spark.createDataFrame([["1"]], schema)
...
TypeError: obj.s: IntegerType can not accept object '1' in type <type 'str'>
```
I suggested to get rid of this but during investigating this, I realised my approach might bring a performance regression as it is a hot path.
Only for SPARK-19507 and https://github.com/apache/spark/pull/17227, It needs more changes to cleanly get rid of the prefix and I rather decided to fix both issues together.
**Proposal**
This PR tried to
- get rid of per-record type dispatch as we do in many code paths in Scala so that it improves the performance (roughly ~25% improvement) - SPARK-21296
This was tested with a simple code `spark.createDataFrame(range(1000000), "int")`. However, I am quite sure the actual improvement in practice is larger than this, in particular, when the schema is complicated.
- improve error message in exception describing field information as prose - SPARK-19507
## How was this patch tested?
Manually tested and unit tests were added in `python/pyspark/sql/tests.py`.
Benchmark - codes: https://gist.github.com/HyukjinKwon/c3397469c56cb26c2d7dd521ed0bc5a3
Error message - codes: https://gist.github.com/HyukjinKwon/b1b2c7f65865444c4a8836435100e398
**Before**
Benchmark:
- Results: https://gist.github.com/HyukjinKwon/4a291dab45542106301a0c1abcdca924
Error message
- Results: https://gist.github.com/HyukjinKwon/57b1916395794ce924faa32b14a3fe19
**After**
Benchmark
- Results: https://gist.github.com/HyukjinKwon/21496feecc4a920e50c4e455f836266e
Error message
- Results: https://gist.github.com/HyukjinKwon/7a494e4557fe32a652ce1236e504a395
Closes #17227
Author: hyukjinkwon <gurwls223@gmail.com>
Author: David Gingrich <david@textio.com>
Closes #18521 from HyukjinKwon/python-type-dispatch.
2017-07-04 08:45:58 -04:00
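A small sketch of the improved error reporting; it relies on the private `_make_type_verifier` helper that the tests below also use, and the exact message wording may differ:
```python
from pyspark.sql.types import StructType, StructField, IntegerType, _make_type_verifier

schema = StructType([StructField("s", IntegerType(), nullable=False)])
verify = _make_type_verifier(schema)
try:
    verify(["1"])  # a str where an IntegerType field is expected
except TypeError as e:
    # The message should now name the offending field (e.g. "field s")
    # instead of using a bare "obj" prefix.
    print(e)
```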
|
|
|
|
|
|
|
class DataTypeVerificationTests(unittest.TestCase):
|
|
|
|
|
|
|
|
def test_verify_type_exception_msg(self):
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
ValueError,
|
|
|
|
"test_name",
|
|
|
|
lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))
|
|
|
|
|
|
|
|
schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
TypeError,
|
|
|
|
"field b in field a",
|
|
|
|
lambda: _make_type_verifier(schema)([["data"]]))
|
|
|
|
|
|
|
|
def test_verify_type_ok_nullable(self):
|
|
|
|
obj = None
|
|
|
|
types = [IntegerType(), FloatType(), StringType(), StructType([])]
|
|
|
|
for data_type in types:
|
|
|
|
try:
|
|
|
|
_make_type_verifier(data_type, nullable=True)(obj)
|
|
|
|
except Exception:
|
|
|
|
self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type))
|
|
|
|
|
|
|
|
def test_verify_type_not_nullable(self):
|
|
|
|
import array
|
|
|
|
import datetime
|
|
|
|
import decimal
|
|
|
|
|
|
|
|
schema = StructType([
|
|
|
|
StructField('s', StringType(), nullable=False),
|
|
|
|
StructField('i', IntegerType(), nullable=True)])
|
|
|
|
|
|
|
|
class MyObj:
|
|
|
|
def __init__(self, **kwargs):
|
|
|
|
for k, v in kwargs.items():
|
|
|
|
setattr(self, k, v)
|
|
|
|
|
|
|
|
# obj, data_type
|
|
|
|
success_spec = [
|
|
|
|
# String
|
|
|
|
("", StringType()),
|
|
|
|
(u"", StringType()),
|
|
|
|
(1, StringType()),
|
|
|
|
(1.0, StringType()),
|
|
|
|
([], StringType()),
|
|
|
|
({}, StringType()),
|
|
|
|
|
|
|
|
# UDT
|
|
|
|
(ExamplePoint(1.0, 2.0), ExamplePointUDT()),
|
|
|
|
|
|
|
|
# Boolean
|
|
|
|
(True, BooleanType()),
|
|
|
|
|
|
|
|
# Byte
|
|
|
|
(-(2**7), ByteType()),
|
|
|
|
(2**7 - 1, ByteType()),
|
|
|
|
|
|
|
|
# Short
|
|
|
|
(-(2**15), ShortType()),
|
|
|
|
(2**15 - 1, ShortType()),
|
|
|
|
|
|
|
|
# Integer
|
|
|
|
(-(2**31), IntegerType()),
|
|
|
|
(2**31 - 1, IntegerType()),
|
|
|
|
|
|
|
|
# Long
|
|
|
|
(2**64, LongType()),
|
|
|
|
|
|
|
|
# Float & Double
|
|
|
|
(1.0, FloatType()),
|
|
|
|
(1.0, DoubleType()),
|
|
|
|
|
|
|
|
# Decimal
|
|
|
|
(decimal.Decimal("1.0"), DecimalType()),
|
|
|
|
|
|
|
|
# Binary
|
|
|
|
(bytearray([1, 2]), BinaryType()),
|
|
|
|
|
|
|
|
# Date/Timestamp
|
|
|
|
(datetime.date(2000, 1, 2), DateType()),
|
|
|
|
(datetime.datetime(2000, 1, 2, 3, 4), DateType()),
|
|
|
|
(datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
|
|
|
|
|
|
|
|
# Array
|
|
|
|
([], ArrayType(IntegerType())),
|
|
|
|
(["1", None], ArrayType(StringType(), containsNull=True)),
|
|
|
|
([1, 2], ArrayType(IntegerType())),
|
|
|
|
((1, 2), ArrayType(IntegerType())),
|
|
|
|
(array.array('h', [1, 2]), ArrayType(IntegerType())),
|
|
|
|
|
|
|
|
# Map
|
|
|
|
({}, MapType(StringType(), IntegerType())),
|
|
|
|
({"a": 1}, MapType(StringType(), IntegerType())),
|
|
|
|
({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),
|
|
|
|
|
|
|
|
# Struct
|
|
|
|
({"s": "a", "i": 1}, schema),
|
|
|
|
({"s": "a", "i": None}, schema),
|
|
|
|
({"s": "a"}, schema),
|
|
|
|
({"s": "a", "f": 1.0}, schema),
|
|
|
|
(Row(s="a", i=1), schema),
|
|
|
|
(Row(s="a", i=None), schema),
|
|
|
|
(Row(s="a", i=1, f=1.0), schema),
|
|
|
|
(["a", 1], schema),
|
|
|
|
(["a", None], schema),
|
|
|
|
(("a", 1), schema),
|
|
|
|
(MyObj(s="a", i=1), schema),
|
|
|
|
(MyObj(s="a", i=None), schema),
|
|
|
|
(MyObj(s="a"), schema),
|
|
|
|
]
|
|
|
|
|
|
|
|
# obj, data_type, exception class
|
|
|
|
failure_spec = [
|
|
|
|
# String (match anything but None)
|
|
|
|
(None, StringType(), ValueError),
|
|
|
|
|
|
|
|
# UDT
|
|
|
|
(ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
|
|
|
|
|
|
|
|
# Boolean
|
|
|
|
(1, BooleanType(), TypeError),
|
|
|
|
("True", BooleanType(), TypeError),
|
|
|
|
([1], BooleanType(), TypeError),
|
|
|
|
|
|
|
|
# Byte
|
|
|
|
(-(2**7) - 1, ByteType(), ValueError),
|
|
|
|
(2**7, ByteType(), ValueError),
|
|
|
|
("1", ByteType(), TypeError),
|
|
|
|
(1.0, ByteType(), TypeError),
|
|
|
|
|
|
|
|
# Short
|
|
|
|
(-(2**15) - 1, ShortType(), ValueError),
|
|
|
|
(2**15, ShortType(), ValueError),
|
|
|
|
|
|
|
|
# Integer
|
|
|
|
(-(2**31) - 1, IntegerType(), ValueError),
|
|
|
|
(2**31, IntegerType(), ValueError),
|
|
|
|
|
|
|
|
# Float & Double
|
|
|
|
(1, FloatType(), TypeError),
|
|
|
|
(1, DoubleType(), TypeError),
|
|
|
|
|
|
|
|
# Decimal
|
|
|
|
(1.0, DecimalType(), TypeError),
|
|
|
|
(1, DecimalType(), TypeError),
|
|
|
|
("1.0", DecimalType(), TypeError),
|
|
|
|
|
|
|
|
# Binary
|
|
|
|
(1, BinaryType(), TypeError),
|
|
|
|
|
|
|
|
# Date/Timestamp
|
|
|
|
("2000-01-02", DateType(), TypeError),
|
|
|
|
(946811040, TimestampType(), TypeError),
|
|
|
|
|
|
|
|
# Array
|
|
|
|
(["1", None], ArrayType(StringType(), containsNull=False), ValueError),
|
|
|
|
([1, "2"], ArrayType(IntegerType()), TypeError),
|
|
|
|
|
|
|
|
# Map
|
|
|
|
({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
|
|
|
|
({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
|
|
|
|
({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
|
|
|
|
ValueError),
|
|
|
|
|
|
|
|
# Struct
|
|
|
|
({"s": "a", "i": "1"}, schema, TypeError),
|
|
|
|
(Row(s="a"), schema, ValueError), # Row can't have missing field
|
|
|
|
(Row(s="a", i="1"), schema, TypeError),
|
|
|
|
(["a"], schema, ValueError),
|
|
|
|
(["a", "1"], schema, TypeError),
|
|
|
|
(MyObj(s="a", i="1"), schema, TypeError),
|
|
|
|
(MyObj(s=None, i="1"), schema, ValueError),
|
|
|
|
]
|
|
|
|
|
|
|
|
# Check success cases
|
|
|
|
for obj, data_type in success_spec:
|
|
|
|
try:
|
|
|
|
_make_type_verifier(data_type, nullable=False)(obj)
|
|
|
|
except Exception:
|
|
|
|
self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type))
|
|
|
|
|
|
|
|
# Check failure cases
|
|
|
|
for obj, data_type, exp in failure_spec:
|
|
|
|
msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
|
|
|
|
with self.assertRaises(exp, msg=msg):
|
|
|
|
_make_type_verifier(data_type, nullable=False)(obj)
|
|
|
|
|
|
|
|
|
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
## What changes were proposed in this pull request?
This PR proposes to explicitly specify Pandas and PyArrow versions in PySpark tests to skip or test.
We declared the extra dependencies:
https://github.com/apache/spark/blob/b8bfce51abf28c66ba1fc67b0f25fe1617c81025/python/setup.py#L204
In case of PyArrow:
Currently we only check if pyarrow is installed or not without checking the version. It already fails to run tests. For example, if PyArrow 0.7.0 is installed:
```
======================================================================
ERROR: test_vectorized_udf_wrong_return_type (pyspark.sql.tests.ScalarPandasUDF)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/.../spark/python/pyspark/sql/tests.py", line 4019, in test_vectorized_udf_wrong_return_type
f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
File "/.../spark/python/pyspark/sql/functions.py", line 2309, in pandas_udf
return _create_udf(f=f, returnType=return_type, evalType=eval_type)
File "/.../spark/python/pyspark/sql/udf.py", line 47, in _create_udf
require_minimum_pyarrow_version()
File "/.../spark/python/pyspark/sql/utils.py", line 132, in require_minimum_pyarrow_version
"however, your version was %s." % pyarrow.__version__)
ImportError: pyarrow >= 0.8.0 must be installed on calling Python process; however, your version was 0.7.0.
----------------------------------------------------------------------
Ran 33 tests in 8.098s
FAILED (errors=33)
```
In case of Pandas:
There are a few tests for old Pandas that previously ran only when the Pandas version was lower, and I rewrote them to run both when the Pandas version is lower and when Pandas is missing.
## How was this patch tested?
Manually tested by modifying the condition:
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
```
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
```
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
```
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
```
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20487 from HyukjinKwon/pyarrow-pandas-skip.
2018-02-07 09:28:10 -05:00
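A hedged sketch of how the skip flags used by the decorator below can be computed from the version helpers mentioned in the change (the names mirror the `_have_pandas` / `_pandas_requirement_message` pattern; the actual module-level code in this file may differ):
```python
try:
    from pyspark.sql.utils import require_minimum_pandas_version
    require_minimum_pandas_version()     # raises ImportError if Pandas is missing or too old
    _have_pandas = True
    _pandas_requirement_message = None
except ImportError as e:
    _have_pandas = False
    _pandas_requirement_message = str(e)
```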
|
|
|
@unittest.skipIf(
|
|
|
|
not _have_pandas or not _have_pyarrow,
|
|
|
|
_pandas_requirement_message or _pyarrow_requirement_message)
|
2017-10-29 22:50:22 -04:00
|
|
|
class ArrowTests(ReusedSQLTestCase):
|
2017-07-10 18:21:03 -04:00
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def setUpClass(cls):
|
2018-02-06 01:52:25 -05:00
|
|
|
from datetime import date, datetime
|
2017-12-26 07:37:25 -05:00
|
|
|
from decimal import Decimal
|
2017-10-29 22:50:22 -04:00
|
|
|
ReusedSQLTestCase.setUpClass()
|
2017-10-27 02:02:46 -04:00
|
|
|
|
|
|
|
# Synchronize default timezone between Python and Java
|
|
|
|
cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
|
|
|
|
tz = "America/Los_Angeles"
|
|
|
|
os.environ["TZ"] = tz
|
|
|
|
time.tzset()
|
|
|
|
|
|
|
|
cls.spark.conf.set("spark.sql.session.timeZone", tz)
|
2017-10-10 01:35:34 -04:00
|
|
|
cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
|
[SPARK-23380][PYTHON] Adds a conf for Arrow fallback in toPandas/createDataFrame with Pandas DataFrame
## What changes were proposed in this pull request?
This PR adds a configuration to control the fallback of Arrow optimization for `toPandas` and `createDataFrame` with Pandas DataFrame.
## How was this patch tested?
Manually tested and unit tests added.
You can test this by:
**`createDataFrame`**
```python
spark.conf.set("spark.sql.execution.arrow.enabled", False)
pdf = spark.createDataFrame([[{'a': 1}]]).toPandas()
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True)
spark.createDataFrame(pdf, "a: map<string, int>")
```
```python
spark.conf.set("spark.sql.execution.arrow.enabled", False)
pdf = spark.createDataFrame([[{'a': 1}]]).toPandas()
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False)
spark.createDataFrame(pdf, "a: map<string, int>")
```
**`toPandas`**
```python
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True)
spark.createDataFrame([[{'a': 1}]]).toPandas()
```
```python
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False)
spark.createDataFrame([[{'a': 1}]]).toPandas()
```
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20678 from HyukjinKwon/SPARK-23380-conf.
2018-03-08 06:22:07 -05:00
|
|
|
# Disable fallback by default to easily detect the failures.
|
|
|
|
cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
|
2017-07-10 18:21:03 -04:00
|
|
|
cls.schema = StructType([
|
|
|
|
StructField("1_str_t", StringType(), True),
|
|
|
|
StructField("2_int_t", IntegerType(), True),
|
|
|
|
StructField("3_long_t", LongType(), True),
|
|
|
|
StructField("4_float_t", FloatType(), True),
|
2017-10-27 02:02:46 -04:00
|
|
|
StructField("5_double_t", DoubleType(), True),
|
2017-12-26 07:37:25 -05:00
|
|
|
StructField("6_decimal_t", DecimalType(38, 18), True),
|
|
|
|
StructField("7_date_t", DateType(), True),
|
|
|
|
StructField("8_timestamp_t", TimestampType(), True)])
|
|
|
|
cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
|
2018-02-06 01:52:25 -05:00
|
|
|
date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
|
2017-12-26 07:37:25 -05:00
|
|
|
(u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
|
2018-02-06 01:52:25 -05:00
|
|
|
date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
|
2017-12-26 07:37:25 -05:00
|
|
|
(u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
|
2018-02-06 01:52:25 -05:00
|
|
|
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
|
2017-10-27 02:02:46 -04:00
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def tearDownClass(cls):
|
|
|
|
del os.environ["TZ"]
|
|
|
|
if cls.tz_prev is not None:
|
|
|
|
os.environ["TZ"] = cls.tz_prev
|
|
|
|
time.tzset()
|
2017-10-29 22:50:22 -04:00
|
|
|
ReusedSQLTestCase.tearDownClass()
|
2017-07-10 18:21:03 -04:00
|
|
|
|
2017-11-12 23:16:01 -05:00
|
|
|
def create_pandas_data_frame(self):
|
|
|
|
import pandas as pd
|
|
|
|
import numpy as np
|
|
|
|
data_dict = {}
|
|
|
|
for j, name in enumerate(self.schema.names):
|
|
|
|
data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
|
|
|
|
# need to convert these to numpy types first
|
|
|
|
data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
|
|
|
|
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
|
|
|
|
return pd.DataFrame(data=data_dict)
|
|
|
|
|
2018-03-08 06:22:07 -05:00
|
|
|
def test_toPandas_fallback_enabled(self):
|
|
|
|
import pandas as pd
|
|
|
|
|
|
|
|
with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
|
|
|
|
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
|
|
|
|
df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
|
|
|
|
with QuietTest(self.sc):
|
|
|
|
with warnings.catch_warnings(record=True) as warns:
|
|
|
|
pdf = df.toPandas()
|
|
|
|
# Catch and check the last UserWarning.
|
|
|
|
user_warns = [
|
|
|
|
warn.message for warn in warns if isinstance(warn.message, UserWarning)]
|
|
|
|
self.assertTrue(len(user_warns) > 0)
|
|
|
|
self.assertTrue(
|
|
|
|
"Attempts non-optimization" in _exception_message(user_warns[-1]))
|
|
|
|
self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
|
|
|
|
|
|
|
|
def test_toPandas_fallback_disabled(self):
|
2017-12-26 07:37:25 -05:00
|
|
|
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
|
2017-10-27 02:02:46 -04:00
|
|
|
df = self.spark.createDataFrame([(None,)], schema=schema)
|
2017-07-10 18:21:03 -04:00
|
|
|
with QuietTest(self.sc):
|
2018-02-16 12:41:17 -05:00
|
|
|
with self.assertRaisesRegexp(Exception, 'Unsupported type'):
|
|
|
|
df.toPandas()
|
|
|
|
|
2017-07-10 18:21:03 -04:00
|
|
|
def test_null_conversion(self):
|
|
|
|
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
|
|
|
|
self.data)
|
|
|
|
pdf = df_null.toPandas()
|
|
|
|
null_counts = pdf.isnull().sum().tolist()
|
|
|
|
self.assertTrue(all([c == 1 for c in null_counts]))
|
|
|
|
|
2017-11-28 03:45:22 -05:00
|
|
|
def _toPandas_arrow_toggle(self, df):
|
2017-10-10 01:35:34 -04:00
|
|
|
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
|
2017-11-12 23:16:01 -05:00
|
|
|
try:
|
|
|
|
pdf = df.toPandas()
|
|
|
|
finally:
|
|
|
|
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
|
2017-07-10 18:21:03 -04:00
|
|
|
pdf_arrow = df.toPandas()
|
2017-11-28 03:45:22 -05:00
|
|
|
return pdf, pdf_arrow
|
|
|
|
|
|
|
|
def test_toPandas_arrow_toggle(self):
|
|
|
|
df = self.spark.createDataFrame(self.data, schema=self.schema)
|
|
|
|
pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
|
2018-02-06 01:52:25 -05:00
|
|
|
expected = self.create_pandas_data_frame()
|
|
|
|
self.assertPandasEqual(expected, pdf)
|
|
|
|
self.assertPandasEqual(expected, pdf_arrow)
|
2017-07-10 18:21:03 -04:00
|
|
|
|
2017-11-28 03:45:22 -05:00
|
|
|
def test_toPandas_respect_session_timezone(self):
|
|
|
|
df = self.spark.createDataFrame(self.data, schema=self.schema)
|
|
|
|
orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
|
|
|
|
try:
|
|
|
|
timezone = "America/New_York"
|
|
|
|
self.spark.conf.set("spark.sql.session.timeZone", timezone)
|
|
|
|
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
|
|
|
|
try:
|
|
|
|
pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
|
2018-01-23 00:11:30 -05:00
|
|
|
self.assertPandasEqual(pdf_arrow_la, pdf_la)
|
2017-11-28 03:45:22 -05:00
|
|
|
finally:
|
|
|
|
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
|
|
|
|
pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
|
2018-01-23 00:11:30 -05:00
|
|
|
self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
|
2017-11-28 03:45:22 -05:00
|
|
|
|
|
|
|
self.assertFalse(pdf_ny.equals(pdf_la))
|
|
|
|
|
|
|
|
from pyspark.sql.types import _check_series_convert_timestamps_local_tz
|
|
|
|
pdf_la_corrected = pdf_la.copy()
|
|
|
|
for field in self.schema:
|
|
|
|
if isinstance(field.dataType, TimestampType):
|
|
|
|
pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
|
|
|
|
pdf_la_corrected[field.name], timezone)
|
2018-01-23 00:11:30 -05:00
|
|
|
self.assertPandasEqual(pdf_ny, pdf_la_corrected)
|
2017-11-28 03:45:22 -05:00
|
|
|
finally:
|
|
|
|
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
|
|
|
|
|
2017-07-10 18:21:03 -04:00
|
|
|
def test_pandas_round_trip(self):
|
2017-11-12 23:16:01 -05:00
|
|
|
pdf = self.create_pandas_data_frame()
|
2017-07-10 18:21:03 -04:00
|
|
|
df = self.spark.createDataFrame(self.data, schema=self.schema)
|
|
|
|
pdf_arrow = df.toPandas()
|
2018-01-23 00:11:30 -05:00
|
|
|
self.assertPandasEqual(pdf_arrow, pdf)
|
2017-07-10 18:21:03 -04:00
|
|
|
|
|
|
|
def test_filtered_frame(self):
|
|
|
|
df = self.spark.range(3).toDF("i")
|
|
|
|
pdf = df.filter("i < 0").toPandas()
|
|
|
|
self.assertEqual(len(pdf.columns), 1)
|
|
|
|
self.assertEqual(pdf.columns[0], "i")
|
|
|
|
self.assertTrue(pdf.empty)
|
|
|
|
|
2017-11-28 03:45:22 -05:00
|
|
|
def _createDataFrame_toggle(self, pdf, schema=None):
|
2017-11-12 23:16:01 -05:00
|
|
|
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
|
|
|
|
try:
|
2017-11-28 03:45:22 -05:00
|
|
|
df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
|
2017-11-12 23:16:01 -05:00
|
|
|
finally:
|
|
|
|
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
|
2017-11-28 03:45:22 -05:00
|
|
|
df_arrow = self.spark.createDataFrame(pdf, schema=schema)
|
|
|
|
return df_no_arrow, df_arrow
|
|
|
|
|
|
|
|
def test_createDataFrame_toggle(self):
|
|
|
|
pdf = self.create_pandas_data_frame()
|
|
|
|
df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema)
|
2017-11-12 23:16:01 -05:00
|
|
|
self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
|
|
|
|
|
2017-11-28 03:45:22 -05:00
|
|
|
def test_createDataFrame_respect_session_timezone(self):
|
|
|
|
from datetime import timedelta
|
|
|
|
pdf = self.create_pandas_data_frame()
|
|
|
|
orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
|
|
|
|
try:
|
|
|
|
timezone = "America/New_York"
|
|
|
|
self.spark.conf.set("spark.sql.session.timeZone", timezone)
|
|
|
|
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
|
|
|
|
try:
|
|
|
|
df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
|
|
|
|
result_la = df_no_arrow_la.collect()
|
|
|
|
result_arrow_la = df_arrow_la.collect()
|
|
|
|
self.assertEqual(result_la, result_arrow_la)
|
|
|
|
finally:
|
|
|
|
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
|
|
|
|
df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
|
|
|
|
result_ny = df_no_arrow_ny.collect()
|
|
|
|
result_arrow_ny = df_arrow_ny.collect()
|
|
|
|
self.assertEqual(result_ny, result_arrow_ny)
|
|
|
|
|
|
|
|
self.assertNotEqual(result_ny, result_la)
|
|
|
|
|
|
|
|
# Correct result_la by adjusting for the 3-hour difference between Los Angeles and New York
|
2017-12-26 07:37:25 -05:00
|
|
|
result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v
|
2017-11-28 03:45:22 -05:00
|
|
|
for k, v in row.asDict().items()})
|
|
|
|
for row in result_la]
|
|
|
|
self.assertEqual(result_ny, result_la_corrected)
|
|
|
|
finally:
|
|
|
|
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
|
|
|
|
|
2017-11-12 23:16:01 -05:00
|
|
|
def test_createDataFrame_with_schema(self):
|
|
|
|
pdf = self.create_pandas_data_frame()
|
|
|
|
df = self.spark.createDataFrame(pdf, schema=self.schema)
|
|
|
|
self.assertEquals(self.schema, df.schema)
|
|
|
|
pdf_arrow = df.toPandas()
|
2018-01-23 00:11:30 -05:00
|
|
|
self.assertPandasEqual(pdf_arrow, pdf)
|
2017-11-12 23:16:01 -05:00
|
|
|
|
|
|
|
def test_createDataFrame_with_incorrect_schema(self):
|
|
|
|
pdf = self.create_pandas_data_frame()
|
|
|
|
wrong_schema = StructType(list(reversed(self.schema)))
|
|
|
|
with QuietTest(self.sc):
|
2018-03-08 06:22:07 -05:00
|
|
|
with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"):
|
2017-11-12 23:16:01 -05:00
|
|
|
self.spark.createDataFrame(pdf, schema=wrong_schema)
|
|
|
|
|
|
|
|
def test_createDataFrame_with_names(self):
|
|
|
|
pdf = self.create_pandas_data_frame()
|
|
|
|
# Test that schema as a list of column names gets applied
|
2017-12-26 07:37:25 -05:00
|
|
|
df = self.spark.createDataFrame(pdf, schema=list('abcdefgh'))
|
|
|
|
self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
|
2017-11-12 23:16:01 -05:00
|
|
|
# Test that schema as tuple of column names gets applied
|
2017-12-26 07:37:25 -05:00
|
|
|
df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh'))
|
|
|
|
self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
|
2017-11-15 09:35:13 -05:00
|
|
|
|
|
|
|
def test_createDataFrame_column_name_encoding(self):
|
|
|
|
import pandas as pd
|
|
|
|
pdf = pd.DataFrame({u'a': [1]})
|
|
|
|
columns = self.spark.createDataFrame(pdf).columns
|
|
|
|
self.assertTrue(isinstance(columns[0], str))
|
|
|
|
self.assertEquals(columns[0], 'a')
|
|
|
|
columns = self.spark.createDataFrame(pdf, [u'b']).columns
|
|
|
|
self.assertTrue(isinstance(columns[0], str))
|
|
|
|
self.assertEquals(columns[0], 'b')
|
2017-11-12 23:16:01 -05:00
|
|
|
|
|
|
|
def test_createDataFrame_with_single_data_type(self):
|
|
|
|
import pandas as pd
|
|
|
|
with QuietTest(self.sc):
|
2018-03-08 06:22:07 -05:00
|
|
|
with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"):
|
2017-11-12 23:16:01 -05:00
|
|
|
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
|
|
|
|
|
|
|
|
def test_createDataFrame_does_not_modify_input(self):
|
2017-12-21 06:43:56 -05:00
|
|
|
import pandas as pd
|
2017-11-12 23:16:01 -05:00
|
|
|
# Some series get converted for Spark to consume; this makes sure the input is unchanged
|
|
|
|
pdf = self.create_pandas_data_frame()
|
|
|
|
# Use a nanosecond value to make sure it is not truncated
|
2017-12-26 07:37:25 -05:00
|
|
|
pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1)
|
2017-11-12 23:16:01 -05:00
|
|
|
# Integers with nulls will get NaNs filled with 0 and will be cast
|
|
|
|
pdf.ix[1, '2_int_t'] = None
|
|
|
|
pdf_copy = pdf.copy(deep=True)
|
|
|
|
self.spark.createDataFrame(pdf, schema=self.schema)
|
|
|
|
self.assertTrue(pdf.equals(pdf_copy))
|
|
|
|
|
|
|
|
def test_schema_conversion_roundtrip(self):
|
|
|
|
from pyspark.sql.types import from_arrow_schema, to_arrow_schema
|
|
|
|
arrow_schema = to_arrow_schema(self.schema)
|
|
|
|
schema_rt = from_arrow_schema(arrow_schema)
|
|
|
|
self.assertEquals(self.schema, schema_rt)
|
|
|
|
|
2018-01-01 17:13:27 -05:00
|
|
|
def test_createDataFrame_with_array_type(self):
|
|
|
|
import pandas as pd
|
|
|
|
pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
|
|
|
|
df, df_arrow = self._createDataFrame_toggle(pdf)
|
|
|
|
result = df.collect()
|
|
|
|
result_arrow = df_arrow.collect()
|
|
|
|
expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
|
|
|
|
for r in range(len(expected)):
|
|
|
|
for e in range(len(expected[r])):
|
|
|
|
self.assertTrue(expected[r][e] == result_arrow[r][e] and
|
|
|
|
result[r][e] == result_arrow[r][e])
|
|
|
|
|
|
|
|
def test_toPandas_with_array_type(self):
|
|
|
|
expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])]
|
|
|
|
array_schema = StructType([StructField("a", ArrayType(IntegerType())),
|
|
|
|
StructField("b", ArrayType(StringType()))])
|
|
|
|
df = self.spark.createDataFrame(expected, schema=array_schema)
|
|
|
|
pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
|
|
|
|
result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
|
|
|
|
result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)]
|
|
|
|
for r in range(len(expected)):
|
|
|
|
for e in range(len(expected[r])):
|
|
|
|
self.assertTrue(expected[r][e] == result_arrow[r][e] and
|
|
|
|
result[r][e] == result_arrow[r][e])
|
|
|
|
|
2018-01-10 00:55:24 -05:00
|
|
|
def test_createDataFrame_with_int_col_names(self):
|
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
|
|
|
pdf = pd.DataFrame(np.random.rand(4, 2))
|
|
|
|
df, df_arrow = self._createDataFrame_toggle(pdf)
|
|
|
|
pdf_col_names = [str(c) for c in pdf.columns]
|
|
|
|
self.assertEqual(pdf_col_names, df.columns)
|
|
|
|
self.assertEqual(pdf_col_names, df_arrow.columns)
|
|
|
|
|
2018-03-08 06:22:07 -05:00
|
|
|
    def test_createDataFrame_fallback_enabled(self):
        import pandas as pd

        with QuietTest(self.sc):
            with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
                with warnings.catch_warnings(record=True) as warns:
                    df = self.spark.createDataFrame(
                        pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
                    # Catch and check the last UserWarning.
                    user_warns = [
                        warn.message for warn in warns if isinstance(warn.message, UserWarning)]
                    self.assertTrue(len(user_warns) > 0)
                    self.assertTrue(
                        "Attempts non-optimization" in _exception_message(user_warns[-1]))
                    self.assertEqual(df.collect(), [Row(a={u'a': 1})])

    def test_createDataFrame_fallback_disabled(self):
        import pandas as pd

        with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, 'Unsupported type'):
                self.spark.createDataFrame(
                    pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")

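    # Note: both fallback tests above rely on "a: map<string, int>" being a schema the
    # Arrow conversion path cannot handle here, so createDataFrame must either fall back
    # to the non-Arrow path with a UserWarning or raise, depending on
    # spark.sql.execution.arrow.fallback.enabled.
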
    # Regression test for SPARK-23314
    def test_timestamp_dst(self):
        import pandas as pd
        # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
        dt = [datetime.datetime(2015, 11, 1, 0, 30),
              datetime.datetime(2015, 11, 1, 1, 30),
              datetime.datetime(2015, 11, 1, 2, 30)]
        pdf = pd.DataFrame({'time': dt})

        df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
        df_from_pandas = self.spark.createDataFrame(pdf)

        self.assertPandasEqual(pdf, df_from_python.toPandas())
        self.assertPandasEqual(pdf, df_from_pandas.toPandas())


@unittest.skipIf(
    not _have_pandas or not _have_pyarrow,
    _pandas_requirement_message or _pyarrow_requirement_message)
class PandasUDFTests(ReusedSQLTestCase):

    def test_pandas_udf_basic(self):
        from pyspark.rdd import PythonEvalType
        from pyspark.sql.functions import pandas_udf, PandasUDFType

        udf = pandas_udf(lambda x: x, DoubleType())
        self.assertEqual(udf.returnType, DoubleType())
        self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR)
        self.assertEqual(udf.returnType, DoubleType())
        self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR)
        self.assertEqual(udf.returnType, DoubleType())
        self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]),
                         PandasUDFType.GROUPED_MAP)
        self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
        self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        udf = pandas_udf(lambda x: x, 'v double',
                         functionType=PandasUDFType.GROUPED_MAP)
        self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        udf = pandas_udf(lambda x: x, returnType='v double',
                         functionType=PandasUDFType.GROUPED_MAP)
        self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

    def test_pandas_udf_decorator(self):
        from pyspark.rdd import PythonEvalType
        from pyspark.sql.functions import pandas_udf, PandasUDFType
        from pyspark.sql.types import StructType, StructField, DoubleType

        @pandas_udf(DoubleType())
        def foo(x):
            return x
        self.assertEqual(foo.returnType, DoubleType())
        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        @pandas_udf(returnType=DoubleType())
        def foo(x):
            return x
        self.assertEqual(foo.returnType, DoubleType())
        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        schema = StructType([StructField("v", DoubleType())])

        @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        @pandas_udf('v double', PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, DoubleType())
        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

    def test_udf_wrong_arg(self):
        from pyspark.sql.functions import pandas_udf, PandasUDFType

        with QuietTest(self.sc):
            with self.assertRaises(ParseException):
                @pandas_udf('blah')
                def foo(x):
                    return x
            with self.assertRaisesRegexp(ValueError, 'Invalid returnType.*None'):
                @pandas_udf(functionType=PandasUDFType.SCALAR)
                def foo(x):
                    return x
            with self.assertRaisesRegexp(ValueError, 'Invalid functionType'):
                @pandas_udf('double', 100)
                def foo(x):
                    return x

            with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'):
                pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR)
            with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'):
                @pandas_udf(LongType(), PandasUDFType.SCALAR)
                def zero_with_type():
                    return 1

            with self.assertRaisesRegexp(TypeError, 'Invalid returnType'):
                @pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
                def foo(df):
                    return df
            with self.assertRaisesRegexp(TypeError, 'Invalid returnType'):
                @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP)
                def foo(df):
                    return df
            with self.assertRaisesRegexp(ValueError, 'Invalid function'):
                @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
                def foo(k, v, w):
                    return k


@unittest.skipIf(
    not _have_pandas or not _have_pyarrow,
    _pandas_requirement_message or _pyarrow_requirement_message)
class ScalarPandasUDFTests(ReusedSQLTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedSQLTestCase.setUpClass()

        # Synchronize default timezone between Python and Java
        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
        tz = "America/Los_Angeles"
        os.environ["TZ"] = tz
        time.tzset()

        cls.sc.environment["TZ"] = tz
        cls.spark.conf.set("spark.sql.session.timeZone", tz)

    @classmethod
    def tearDownClass(cls):
        del os.environ["TZ"]
        if cls.tz_prev is not None:
            os.environ["TZ"] = cls.tz_prev
        time.tzset()
        ReusedSQLTestCase.tearDownClass()

    @property
    def nondeterministic_vectorized_udf(self):
        from pyspark.sql.functions import pandas_udf

        @pandas_udf('double')
        def random_udf(v):
            import pandas as pd
            import numpy as np
            return pd.Series(np.random.random(len(v)))
        random_udf = random_udf.asNondeterministic()
        return random_udf

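    # The TZ synchronization done in setUpClass keeps the Python workers, the JVM, and the
    # Spark session all on America/Los_Angeles, so the timestamp-related tests below can
    # compare values within a single, known time zone.
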
    def test_vectorized_udf_basic(self):
        from pyspark.sql.functions import pandas_udf, col, array
        df = self.spark.range(10).select(
            col('id').cast('string').alias('str'),
            col('id').cast('int').alias('int'),
            col('id').alias('long'),
            col('id').cast('float').alias('float'),
            col('id').cast('double').alias('double'),
            col('id').cast('decimal').alias('decimal'),
            col('id').cast('boolean').alias('bool'),
            array(col('id')).alias('array_long'))
        f = lambda x: x
        str_f = pandas_udf(f, StringType())
        int_f = pandas_udf(f, IntegerType())
        long_f = pandas_udf(f, LongType())
        float_f = pandas_udf(f, FloatType())
        double_f = pandas_udf(f, DoubleType())
        decimal_f = pandas_udf(f, DecimalType())
        bool_f = pandas_udf(f, BooleanType())
        array_long_f = pandas_udf(f, ArrayType(LongType()))
        res = df.select(str_f(col('str')), int_f(col('int')),
                        long_f(col('long')), float_f(col('float')),
                        double_f(col('double')), decimal_f('decimal'),
                        bool_f(col('bool')), array_long_f('array_long'))
        self.assertEquals(df.collect(), res.collect())

    def test_register_nondeterministic_vectorized_udf_basic(self):
        from pyspark.sql.functions import pandas_udf
        from pyspark.rdd import PythonEvalType
        import random
        random_pandas_udf = pandas_udf(
            lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic()
        self.assertEqual(random_pandas_udf.deterministic, False)
        self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
        nondeterministic_pandas_udf = self.spark.catalog.registerFunction(
            "randomPandasUDF", random_pandas_udf)
        self.assertEqual(nondeterministic_pandas_udf.deterministic, False)
        self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
        [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect()
        self.assertEqual(row[0], 7)

    def test_vectorized_udf_null_boolean(self):
        from pyspark.sql.functions import pandas_udf, col
        data = [(True,), (True,), (None,), (False,)]
        schema = StructType().add("bool", BooleanType())
        df = self.spark.createDataFrame(data, schema)
        bool_f = pandas_udf(lambda x: x, BooleanType())
        res = df.select(bool_f(col('bool')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_null_byte(self):
        from pyspark.sql.functions import pandas_udf, col
        data = [(None,), (2,), (3,), (4,)]
        schema = StructType().add("byte", ByteType())
        df = self.spark.createDataFrame(data, schema)
        byte_f = pandas_udf(lambda x: x, ByteType())
        res = df.select(byte_f(col('byte')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_null_short(self):
        from pyspark.sql.functions import pandas_udf, col
        data = [(None,), (2,), (3,), (4,)]
        schema = StructType().add("short", ShortType())
        df = self.spark.createDataFrame(data, schema)
        short_f = pandas_udf(lambda x: x, ShortType())
        res = df.select(short_f(col('short')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_null_int(self):
        from pyspark.sql.functions import pandas_udf, col
        data = [(None,), (2,), (3,), (4,)]
        schema = StructType().add("int", IntegerType())
        df = self.spark.createDataFrame(data, schema)
        int_f = pandas_udf(lambda x: x, IntegerType())
        res = df.select(int_f(col('int')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_null_long(self):
        from pyspark.sql.functions import pandas_udf, col
        data = [(None,), (2,), (3,), (4,)]
        schema = StructType().add("long", LongType())
        df = self.spark.createDataFrame(data, schema)
        long_f = pandas_udf(lambda x: x, LongType())
        res = df.select(long_f(col('long')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_null_float(self):
        from pyspark.sql.functions import pandas_udf, col
        data = [(3.0,), (5.0,), (-1.0,), (None,)]
        schema = StructType().add("float", FloatType())
        df = self.spark.createDataFrame(data, schema)
        float_f = pandas_udf(lambda x: x, FloatType())
        res = df.select(float_f(col('float')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_null_double(self):
        from pyspark.sql.functions import pandas_udf, col
        data = [(3.0,), (5.0,), (-1.0,), (None,)]
        schema = StructType().add("double", DoubleType())
        df = self.spark.createDataFrame(data, schema)
        double_f = pandas_udf(lambda x: x, DoubleType())
        res = df.select(double_f(col('double')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_null_decimal(self):
        from decimal import Decimal
        from pyspark.sql.functions import pandas_udf, col
        data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
        schema = StructType().add("decimal", DecimalType(38, 18))
        df = self.spark.createDataFrame(data, schema)
        decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18))
        res = df.select(decimal_f(col('decimal')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_null_string(self):
        from pyspark.sql.functions import pandas_udf, col
        data = [("foo",), (None,), ("bar",), ("bar",)]
        schema = StructType().add("str", StringType())
        df = self.spark.createDataFrame(data, schema)
        str_f = pandas_udf(lambda x: x, StringType())
        res = df.select(str_f(col('str')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_string_in_udf(self):
        from pyspark.sql.functions import pandas_udf, col
        import pandas as pd
        df = self.spark.range(10)
        str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType())
        actual = df.select(str_f(col('id')))
        expected = df.select(col('id').cast('string'))
        self.assertEquals(expected.collect(), actual.collect())

    def test_vectorized_udf_datatype_string(self):
        from pyspark.sql.functions import pandas_udf, col
        df = self.spark.range(10).select(
            col('id').cast('string').alias('str'),
            col('id').cast('int').alias('int'),
            col('id').alias('long'),
            col('id').cast('float').alias('float'),
            col('id').cast('double').alias('double'),
            col('id').cast('decimal').alias('decimal'),
            col('id').cast('boolean').alias('bool'))
        f = lambda x: x
        str_f = pandas_udf(f, 'string')
        int_f = pandas_udf(f, 'integer')
        long_f = pandas_udf(f, 'long')
        float_f = pandas_udf(f, 'float')
        double_f = pandas_udf(f, 'double')
        decimal_f = pandas_udf(f, 'decimal(38, 18)')
        bool_f = pandas_udf(f, 'boolean')
        res = df.select(str_f(col('str')), int_f(col('int')),
                        long_f(col('long')), float_f(col('float')),
                        double_f(col('double')), decimal_f('decimal'),
                        bool_f(col('bool')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_array_type(self):
        from pyspark.sql.functions import pandas_udf, col
        data = [([1, 2],), ([3, 4],)]
        array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
        df = self.spark.createDataFrame(data, schema=array_schema)
        array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()))
        result = df.select(array_f(col('array')))
        self.assertEquals(df.collect(), result.collect())

    def test_vectorized_udf_null_array(self):
        from pyspark.sql.functions import pandas_udf, col
        data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
        array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
        df = self.spark.createDataFrame(data, schema=array_schema)
        array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()))
        result = df.select(array_f(col('array')))
        self.assertEquals(df.collect(), result.collect())

    def test_vectorized_udf_complex(self):
        from pyspark.sql.functions import pandas_udf, col, expr
        df = self.spark.range(10).select(
            col('id').cast('int').alias('a'),
            col('id').cast('int').alias('b'),
            col('id').cast('double').alias('c'))
        add = pandas_udf(lambda x, y: x + y, IntegerType())
        power2 = pandas_udf(lambda x: 2 ** x, IntegerType())
        mul = pandas_udf(lambda x, y: x * y, DoubleType())
        res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
        expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
        self.assertEquals(expected.collect(), res.collect())

    def test_vectorized_udf_exception(self):
        from pyspark.sql.functions import pandas_udf, col
        df = self.spark.range(10)
        raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
                df.select(raise_exception(col('id'))).collect()

    def test_vectorized_udf_invalid_length(self):
        from pyspark.sql.functions import pandas_udf, col
        import pandas as pd
        df = self.spark.range(10)
        raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    Exception,
                    'Result vector from pandas_udf was not the required length'):
                df.select(raise_exception(col('id'))).collect()

    def test_vectorized_udf_mix_udf(self):
        from pyspark.sql.functions import pandas_udf, udf, col
        df = self.spark.range(10)
        row_by_row_udf = udf(lambda x: x, LongType())
        pd_udf = pandas_udf(lambda x: x, LongType())
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    Exception,
                    'Can not mix vectorized and non-vectorized UDFs'):
                df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()

    def test_vectorized_udf_chained(self):
        from pyspark.sql.functions import pandas_udf, col
        df = self.spark.range(10)
        f = pandas_udf(lambda x: x + 1, LongType())
        g = pandas_udf(lambda x: x - 1, LongType())
        res = df.select(g(f(col('id'))))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_wrong_return_type(self):
        from pyspark.sql.functions import pandas_udf, col
        df = self.spark.range(10)
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    NotImplementedError,
                    'Invalid returnType.*scalar Pandas UDF.*MapType'):
                pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))

    def test_vectorized_udf_return_scalar(self):
        from pyspark.sql.functions import pandas_udf, col
        df = self.spark.range(10)
        f = pandas_udf(lambda x: 1.0, DoubleType())
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'):
                df.select(f(col('id'))).collect()

    def test_vectorized_udf_decorator(self):
        from pyspark.sql.functions import pandas_udf, col
        df = self.spark.range(10)

        @pandas_udf(returnType=LongType())
        def identity(x):
            return x
        res = df.select(identity(col('id')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_empty_partition(self):
        from pyspark.sql.functions import pandas_udf, col
        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
        f = pandas_udf(lambda x: x, LongType())
        res = df.select(f(col('id')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_varargs(self):
        from pyspark.sql.functions import pandas_udf, col
        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
        f = pandas_udf(lambda *v: v[0], LongType())
        res = df.select(f(col('id')))
        self.assertEquals(df.collect(), res.collect())

    def test_vectorized_udf_unsupported_types(self):
        from pyspark.sql.functions import pandas_udf
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    NotImplementedError,
                    'Invalid returnType.*scalar Pandas UDF.*MapType'):
                pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))

        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    NotImplementedError,
                    'Invalid returnType.*scalar Pandas UDF.*BinaryType'):
                pandas_udf(lambda x: x, BinaryType())

    def test_vectorized_udf_dates(self):
        from pyspark.sql.functions import pandas_udf, col
        from datetime import date
        schema = StructType().add("idx", LongType()).add("date", DateType())
        data = [(0, date(1969, 1, 1),),
                (1, date(2012, 2, 2),),
                (2, None,),
                (3, date(2100, 4, 4),)]
        df = self.spark.createDataFrame(data, schema=schema)

        date_copy = pandas_udf(lambda t: t, returnType=DateType())
        df = df.withColumn("date_copy", date_copy(col("date")))

        @pandas_udf(returnType=StringType())
        def check_data(idx, date, date_copy):
            import pandas as pd
            msgs = []
            is_equal = date.isnull()
            for i in range(len(idx)):
                if (is_equal[i] and data[idx[i]][1] is None) or \
                        date[i] == data[idx[i]][1]:
                    msgs.append(None)
                else:
                    msgs.append(
                        "date values are not equal (date='%s': data[%d][1]='%s')"
                        % (date[i], idx[i], data[idx[i]][1]))
            return pd.Series(msgs)

        result = df.withColumn("check_data",
                               check_data(col("idx"), col("date"), col("date_copy"))).collect()

        self.assertEquals(len(data), len(result))
        for i in range(len(result)):
            self.assertEquals(data[i][1], result[i][1])  # "date" col
            self.assertEquals(data[i][1], result[i][2])  # "date_copy" col
            self.assertIsNone(result[i][3])  # "check_data" col

    def test_vectorized_udf_timestamps(self):
        from pyspark.sql.functions import pandas_udf, col
        from datetime import datetime
        schema = StructType([
            StructField("idx", LongType(), True),
            StructField("timestamp", TimestampType(), True)])
        data = [(0, datetime(1969, 1, 1, 1, 1, 1)),
                (1, datetime(2012, 2, 2, 2, 2, 2)),
                (2, None),
                (3, datetime(2100, 3, 3, 3, 3, 3))]

        df = self.spark.createDataFrame(data, schema=schema)

        # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc
        f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType())
        df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))

        @pandas_udf(returnType=StringType())
        def check_data(idx, timestamp, timestamp_copy):
            import pandas as pd
            msgs = []
            is_equal = timestamp.isnull()  # use this array to check values are equal
            for i in range(len(idx)):
                # Check that timestamps are as expected in the UDF
                if (is_equal[i] and data[idx[i]][1] is None) or \
                        timestamp[i].to_pydatetime() == data[idx[i]][1]:
                    msgs.append(None)
                else:
                    msgs.append(
                        "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')"
                        % (timestamp[i], idx[i], data[idx[i]][1]))
            return pd.Series(msgs)

        result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"),
                                                        col("timestamp_copy"))).collect()
        # Check that collection values are correct
        self.assertEquals(len(data), len(result))
        for i in range(len(result)):
            self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
            self.assertEquals(data[i][1], result[i][2])  # "timestamp_copy" col
            self.assertIsNone(result[i][3])  # "check_data" col

    def test_vectorized_udf_return_timestamp_tz(self):
        from pyspark.sql.functions import pandas_udf, col
        import pandas as pd
        df = self.spark.range(10)

        @pandas_udf(returnType=TimestampType())
        def gen_timestamps(id):
            ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id]
            return pd.Series(ts)

        result = df.withColumn("ts", gen_timestamps(col("id"))).collect()
        spark_ts_t = TimestampType()
        for r in result:
            i, ts = r
            ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime()
            expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
            self.assertEquals(expected, ts)

    def test_vectorized_udf_check_config(self):
        from pyspark.sql.functions import pandas_udf, col
        import pandas as pd
        orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
        self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
        try:
            df = self.spark.range(10, numPartitions=1)

            @pandas_udf(returnType=LongType())
            def check_records_per_batch(x):
                return pd.Series(x.size).repeat(x.size)

            result = df.select(check_records_per_batch(col("id"))).collect()
            for (r,) in result:
                self.assertTrue(r <= 3)
        finally:
            if orig_value is None:
                self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
            else:
                self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)

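    # With maxRecordsPerBatch set to 3 above, each Arrow record batch handed to the pandas
    # UDF holds at most 3 rows, so x.size is expected to be <= 3 for every batch.
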
    def test_vectorized_udf_timestamps_respect_session_timezone(self):
        from pyspark.sql.functions import pandas_udf, col
        from datetime import datetime
        import pandas as pd
        schema = StructType([
            StructField("idx", LongType(), True),
            StructField("timestamp", TimestampType(), True)])
        data = [(1, datetime(1969, 1, 1, 1, 1, 1)),
                (2, datetime(2012, 2, 2, 2, 2, 2)),
                (3, None),
                (4, datetime(2100, 3, 3, 3, 3, 3))]
        df = self.spark.createDataFrame(data, schema=schema)

        f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType())
        internal_value = pandas_udf(
            lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())

        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
        try:
            timezone = "America/New_York"
            self.spark.conf.set("spark.sql.session.timeZone", timezone)
            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
            try:
                df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
                    .withColumn("internal_value", internal_value(col("timestamp")))
                result_la = df_la.select(col("idx"), col("internal_value")).collect()
                # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
                diff = 3 * 60 * 60 * 1000 * 1000 * 1000
                result_la_corrected = \
                    df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
            finally:
                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")

            df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
                .withColumn("internal_value", internal_value(col("timestamp")))
            result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()

            self.assertNotEqual(result_ny, result_la)
            self.assertEqual(result_ny, result_la_corrected)
        finally:
            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

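    # The diff applied above (3 * 60 * 60 * 10**9) is the offset between America/Los_Angeles
    # (the process timezone set in setUpClass) and America/New_York, expressed in nanoseconds
    # because pandas Timestamp.value is a nanosecond-resolution epoch value.
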
    def test_nondeterministic_vectorized_udf(self):
        # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
        from pyspark.sql.functions import udf, pandas_udf, col

        @pandas_udf('double')
        def plus_ten(v):
            return v + 10
        random_udf = self.nondeterministic_vectorized_udf

        df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
        result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()

        self.assertEqual(random_udf.deterministic, False)
        self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))

    def test_nondeterministic_vectorized_udf_in_aggregate(self):
        from pyspark.sql.functions import pandas_udf, sum

        df = self.spark.range(10)
        random_udf = self.nondeterministic_vectorized_udf

        with QuietTest(self.sc):
            with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
                df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
            with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
                df.agg(sum(random_udf(df.id))).collect()

def test_register_vectorized_udf_basic(self):
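# A scalar pandas UDF registered through the catalog keeps the SQL_SCALAR_PANDAS_UDF
# eval type and can be invoked from both the DataFrame API and SQL, matching the
# plain 'a + b' expression.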
|
|
|
|
from pyspark.rdd import PythonEvalType
|
|
|
|
from pyspark.sql.functions import pandas_udf, col, expr
|
|
|
|
df = self.spark.range(10).select(
|
|
|
|
col('id').cast('int').alias('a'),
|
|
|
|
col('id').cast('int').alias('b'))
|
|
|
|
original_add = pandas_udf(lambda x, y: x + y, IntegerType())
|
|
|
|
self.assertEqual(original_add.deterministic, True)
|
2018-01-30 07:55:55 -05:00
|
|
|
self.assertEqual(original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
|
2018-01-16 06:20:33 -05:00
|
|
|
new_add = self.spark.catalog.registerFunction("add1", original_add)
|
|
|
|
res1 = df.select(new_add(col('a'), col('b')))
|
|
|
|
res2 = self.spark.sql(
|
|
|
|
"SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t")
|
|
|
|
expected = df.select(expr('a + b'))
|
|
|
|
self.assertEqual(expected.collect(), res1.collect())
|
|
|
|
self.assertEqual(expected.collect(), res2.collect())
|
|
|
|
|
2018-02-11 03:31:35 -05:00
|
|
|
# Regression test for SPARK-23314
|
|
|
|
def test_timestamp_dst(self):
|
|
|
|
from pyspark.sql.functions import pandas_udf
|
|
|
|
# Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
|
|
|
|
dt = [datetime.datetime(2015, 11, 1, 0, 30),
|
|
|
|
datetime.datetime(2015, 11, 1, 1, 30),
|
|
|
|
datetime.datetime(2015, 11, 1, 2, 30)]
|
|
|
|
df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
|
|
|
|
foo_udf = pandas_udf(lambda x: x, 'timestamp')
|
|
|
|
result = df.withColumn('time', foo_udf(df.time))
|
|
|
|
self.assertEqual(df.collect(), result.collect())
|
|
|
|
|
2018-03-04 23:36:42 -05:00
|
|
|
@unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
|
|
|
|
def test_type_annotation(self):
|
|
|
|
from pyspark.sql.functions import pandas_udf
|
|
|
|
# Regression test to check if type hints can be used. See SPARK-23569.
|
|
|
|
# Note that it throws an error during compilation in lower Python versions if 'exec'
|
|
|
|
# is not used. Also, note that we explicitly use another dictionary to avoid modifications
|
|
|
|
# in the current 'locals()'.
|
|
|
|
#
|
|
|
|
# Hyukjin: I think it's an ugly way to test issues about syntax specific to
|
|
|
|
# higher versions of Python, which we shouldn't encourage. This was the last resort
|
|
|
|
# I could come up with at that time.
|
|
|
|
_locals = {}
|
|
|
|
exec(
|
|
|
|
"import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col",
|
|
|
|
_locals)
|
|
|
|
df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
|
|
|
|
self.assertEqual(df.first()[0], 0)
|
|
|
|
|
2017-10-10 18:32:01 -04:00
|
|
|
|
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
## What changes were proposed in this pull request?
This PR proposes to explicitly specify Pandas and PyArrow versions in PySpark tests to skip or test.
We declared the extra dependencies:
https://github.com/apache/spark/blob/b8bfce51abf28c66ba1fc67b0f25fe1617c81025/python/setup.py#L204
In case of PyArrow:
Currently we only check whether pyarrow is installed, without checking its version, and tests already fail to run in that case. For example, if PyArrow 0.7.0 is installed:
```
======================================================================
ERROR: test_vectorized_udf_wrong_return_type (pyspark.sql.tests.ScalarPandasUDF)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/.../spark/python/pyspark/sql/tests.py", line 4019, in test_vectorized_udf_wrong_return_type
f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
File "/.../spark/python/pyspark/sql/functions.py", line 2309, in pandas_udf
return _create_udf(f=f, returnType=return_type, evalType=eval_type)
File "/.../spark/python/pyspark/sql/udf.py", line 47, in _create_udf
require_minimum_pyarrow_version()
File "/.../spark/python/pyspark/sql/utils.py", line 132, in require_minimum_pyarrow_version
"however, your version was %s." % pyarrow.__version__)
ImportError: pyarrow >= 0.8.0 must be installed on calling Python process; however, your version was 0.7.0.
----------------------------------------------------------------------
Ran 33 tests in 8.098s
FAILED (errors=33)
```
In case of Pandas:
There are a few tests for old Pandas that were run only when the Pandas version was lower, and I rewrote them so that they are tested both when the Pandas version is lower and when Pandas is missing.
## How was this patch tested?
Manually tested by modifying the condition:
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
```
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
```
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
```
```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
```
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20487 from HyukjinKwon/pyarrow-pandas-skip.
2018-02-07 09:28:10 -05:00
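# The commit description above covers the version gating used by these Arrow/Pandas
# test classes. As a minimal sketch only (not the actual pyspark.sql.utils
# implementation; the helper name, default minimum version and message wording here
# are assumptions for illustration), a check of this shape can be written as:
from distutils.version import LooseVersion


def _illustrative_require_minimum_pyarrow_version(minimum="0.8.0"):
    """Raise ImportError if PyArrow is missing or older than `minimum`."""
    try:
        import pyarrow
    except ImportError:
        raise ImportError(
            "PyArrow >= %s must be installed; however, it was not found." % minimum)
    if LooseVersion(pyarrow.__version__) < LooseVersion(minimum):
        raise ImportError(
            "PyArrow >= %s must be installed; however, your version was %s."
            % (minimum, pyarrow.__version__))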
|
|
|
@unittest.skipIf(
|
|
|
|
not _have_pandas or not _have_pyarrow,
|
|
|
|
_pandas_requirement_message or _pyarrow_requirement_message)
|
2018-02-06 15:30:04 -05:00
|
|
|
class GroupedMapPandasUDFTests(ReusedSQLTestCase):
|
2017-10-10 18:32:01 -04:00
|
|
|
|
|
|
|
@property
|
|
|
|
def data(self):
|
|
|
|
from pyspark.sql.functions import array, explode, col, lit
|
|
|
|
return self.spark.range(10).toDF('id') \
|
|
|
|
.withColumn("vs", array([lit(i) for i in range(20, 30)])) \
|
|
|
|
.withColumn("v", explode(col('vs'))).drop('vs')
|
|
|
|
|
[SPARK-23352][PYTHON] Explicitly specify supported types in Pandas UDFs
## What changes were proposed in this pull request?
This PR targets to explicitly specify supported types in Pandas UDFs.
The main change here is to add deduplicated and explicit type checking of `returnType` up front, together with documentation; in the process, it happened to fix multiple things.
1. Currently, we don't support `BinaryType` in Pandas UDFs, for example, see:
```python
from pyspark.sql.functions import pandas_udf
pudf = pandas_udf(lambda x: x, "binary")
df = spark.createDataFrame([[bytearray(1)]])
df.select(pudf("_1")).show()
```
```
...
TypeError: Unsupported type in conversion to Arrow: BinaryType
```
We can document this behaviour for its guide.
2. Also, the grouped aggregate Pandas UDF fails fast on `ArrayType` but seems we can support this case.
```python
from pyspark.sql.functions import pandas_udf, PandasUDFType
foo = pandas_udf(lambda v: v.mean(), 'array<double>', PandasUDFType.GROUPED_AGG)
df = spark.range(100).selectExpr("id", "array(id) as value")
df.groupBy("id").agg(foo("value")).show()
```
```
...
NotImplementedError: ArrayType, StructType and MapType are not supported with PandasUDFType.GROUPED_AGG
```
3. Since we can check the return type ahead, we can fail fast before actual execution.
```python
# we can fail fast at this stage because we know the schema ahead
pandas_udf(lambda x: x, BinaryType())
```
## How was this patch tested?
Manually tested and unit tests for `BinaryType` and `ArrayType(...)` were added.
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20531 from HyukjinKwon/pudf-cleanup.
2018-02-12 06:49:36 -05:00
|
|
|
def test_supported_types(self):
|
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
|
|
|
|
df = self.data.withColumn("arr", array(col("id")))
|
2017-10-10 18:32:01 -04:00
|
|
|
|
2018-03-08 06:29:07 -05:00
|
|
|
# Different forms of group map pandas UDF, results of these are the same
|
|
|
|
|
|
|
|
output_schema = StructType(
|
|
|
|
[StructField('id', LongType()),
|
|
|
|
StructField('v', IntegerType()),
|
|
|
|
StructField('arr', ArrayType(LongType())),
|
|
|
|
StructField('v1', DoubleType()),
|
|
|
|
StructField('v2', LongType())])
|
|
|
|
|
|
|
|
udf1 = pandas_udf(
|
2017-10-10 18:32:01 -04:00
|
|
|
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
|
2018-03-08 06:29:07 -05:00
|
|
|
output_schema,
|
2018-01-30 07:55:55 -05:00
|
|
|
PandasUDFType.GROUPED_MAP
|
2017-11-17 10:43:08 -05:00
|
|
|
)
|
2017-10-10 18:32:01 -04:00
|
|
|
|
2018-03-08 06:29:07 -05:00
|
|
|
udf2 = pandas_udf(
|
|
|
|
lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
|
|
|
|
output_schema,
|
|
|
|
PandasUDFType.GROUPED_MAP
|
|
|
|
)
|
|
|
|
|
|
|
|
udf3 = pandas_udf(
|
|
|
|
lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
|
|
|
|
output_schema,
|
|
|
|
PandasUDFType.GROUPED_MAP
|
|
|
|
)
|
|
|
|
|
|
|
|
result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
|
|
|
|
expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
|
|
|
|
|
|
|
|
result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
|
|
|
|
expected2 = expected1
|
|
|
|
|
|
|
|
result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
|
|
|
|
expected3 = expected1
|
|
|
|
|
|
|
|
self.assertPandasEqual(expected1, result1)
|
|
|
|
self.assertPandasEqual(expected2, result2)
|
|
|
|
self.assertPandasEqual(expected3, result3)
|
2017-10-10 18:32:01 -04:00
|
|
|
|
2018-01-30 07:55:55 -05:00
|
|
|
def test_register_grouped_map_udf(self):
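# Registering a GROUPED_MAP pandas UDF is not supported; registerFunction only
# accepts SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF functions.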
|
2018-01-16 06:20:33 -05:00
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
|
|
|
|
2018-01-30 07:55:55 -05:00
|
|
|
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
|
2018-01-16 06:20:33 -05:00
|
|
|
with QuietTest(self.sc):
|
|
|
|
with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or '
|
2018-01-30 07:55:55 -05:00
|
|
|
'SQL_SCALAR_PANDAS_UDF'):
|
2018-01-16 06:20:33 -05:00
|
|
|
self.spark.catalog.registerFunction("foo_udf", foo_udf)
|
|
|
|
|
2017-10-10 18:32:01 -04:00
|
|
|
def test_decorator(self):
|
2017-11-17 10:43:08 -05:00
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
2017-10-10 18:32:01 -04:00
|
|
|
df = self.data
|
|
|
|
|
2017-11-17 10:43:08 -05:00
|
|
|
@pandas_udf(
|
|
|
|
'id long, v int, v1 double, v2 long',
|
2018-01-30 07:55:55 -05:00
|
|
|
PandasUDFType.GROUPED_MAP
|
2017-11-17 10:43:08 -05:00
|
|
|
)
|
2017-10-10 18:32:01 -04:00
|
|
|
def foo(pdf):
|
|
|
|
return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id)
|
|
|
|
|
|
|
|
result = df.groupby('id').apply(foo).sort('id').toPandas()
|
|
|
|
expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
|
2018-01-23 00:11:30 -05:00
|
|
|
self.assertPandasEqual(expected, result)
|
2017-10-10 18:32:01 -04:00
|
|
|
|
|
|
|
def test_coerce(self):
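# Values returned by the UDF are coerced to the declared return schema: the integer
# column 'v' comes back as double.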
|
2017-11-17 10:43:08 -05:00
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
2017-10-10 18:32:01 -04:00
|
|
|
df = self.data
|
|
|
|
|
|
|
|
foo = pandas_udf(
|
|
|
|
lambda pdf: pdf,
|
2017-11-17 10:43:08 -05:00
|
|
|
'id long, v double',
|
2018-01-30 07:55:55 -05:00
|
|
|
PandasUDFType.GROUPED_MAP
|
2017-11-17 10:43:08 -05:00
|
|
|
)
|
2017-10-10 18:32:01 -04:00
|
|
|
|
|
|
|
result = df.groupby('id').apply(foo).sort('id').toPandas()
|
|
|
|
expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
|
|
|
|
expected = expected.assign(v=expected.v.astype('float64'))
|
2018-01-23 00:11:30 -05:00
|
|
|
self.assertPandasEqual(expected, result)
|
2017-10-10 18:32:01 -04:00
|
|
|
|
|
|
|
def test_complex_groupby(self):
|
2017-11-17 10:43:08 -05:00
|
|
|
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
|
2017-10-10 18:32:01 -04:00
|
|
|
df = self.data
|
|
|
|
|
2017-11-17 10:43:08 -05:00
|
|
|
@pandas_udf(
|
|
|
|
'id long, v int, norm double',
|
2018-01-30 07:55:55 -05:00
|
|
|
PandasUDFType.GROUPED_MAP
|
2017-11-17 10:43:08 -05:00
|
|
|
)
|
2017-10-10 18:32:01 -04:00
|
|
|
def normalize(pdf):
|
|
|
|
v = pdf.v
|
|
|
|
return pdf.assign(norm=(v - v.mean()) / v.std())
|
|
|
|
|
|
|
|
result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas()
|
|
|
|
pdf = df.toPandas()
|
|
|
|
expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func)
|
|
|
|
expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
|
|
|
|
expected = expected.assign(norm=expected.norm.astype('float64'))
|
2018-01-23 00:11:30 -05:00
|
|
|
self.assertPandasEqual(expected, result)
|
2017-10-10 18:32:01 -04:00
|
|
|
|
|
|
|
def test_empty_groupby(self):
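# An empty groupby() treats the whole DataFrame as a single group, so the UDF is
# applied once to all rows.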
|
2017-11-17 10:43:08 -05:00
|
|
|
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
|
2017-10-10 18:32:01 -04:00
|
|
|
df = self.data
|
|
|
|
|
2017-11-17 10:43:08 -05:00
|
|
|
@pandas_udf(
|
|
|
|
'id long, v int, norm double',
|
2018-01-30 07:55:55 -05:00
|
|
|
PandasUDFType.GROUPED_MAP
|
2017-11-17 10:43:08 -05:00
|
|
|
)
|
2017-10-10 18:32:01 -04:00
|
|
|
def normalize(pdf):
|
|
|
|
v = pdf.v
|
|
|
|
return pdf.assign(norm=(v - v.mean()) / v.std())
|
|
|
|
|
|
|
|
result = df.groupby().apply(normalize).sort('id', 'v').toPandas()
|
|
|
|
pdf = df.toPandas()
|
|
|
|
expected = normalize.func(pdf)
|
|
|
|
expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
|
|
|
|
expected = expected.assign(norm=expected.norm.astype('float64'))
|
2018-01-23 00:11:30 -05:00
|
|
|
self.assertPandasEqual(expected, result)
|
2017-10-10 18:32:01 -04:00
|
|
|
|
2017-10-20 15:44:30 -04:00
|
|
|
def test_datatype_string(self):
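# The return schema can be given as a DDL-formatted type string instead of a StructType.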
|
2017-11-17 10:43:08 -05:00
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
2017-10-20 15:44:30 -04:00
|
|
|
df = self.data
|
|
|
|
|
|
|
|
foo_udf = pandas_udf(
|
|
|
|
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
|
2017-11-17 10:43:08 -05:00
|
|
|
'id long, v int, v1 double, v2 long',
|
2018-01-30 07:55:55 -05:00
|
|
|
PandasUDFType.GROUPED_MAP
|
2017-11-17 10:43:08 -05:00
|
|
|
)
|
2017-10-20 15:44:30 -04:00
|
|
|
|
|
|
|
result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
|
|
|
|
expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
|
2018-01-23 00:11:30 -05:00
|
|
|
self.assertPandasEqual(expected, result)
|
2017-10-20 15:44:30 -04:00
|
|
|
|
2017-10-10 18:32:01 -04:00
|
|
|
def test_wrong_return_type(self):
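# An unsupported return type (here a MapType field) should fail fast with
# NotImplementedError when the UDF is created.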
|
2017-11-17 10:43:08 -05:00
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
2017-10-10 18:32:01 -04:00
|
|
|
|
|
|
|
with QuietTest(self.sc):
|
2018-02-12 06:49:36 -05:00
|
|
|
with self.assertRaisesRegexp(
|
|
|
|
NotImplementedError,
|
|
|
|
'Invalid returnType.*grouped map Pandas UDF.*MapType'):
|
|
|
|
pandas_udf(
|
|
|
|
lambda pdf: pdf,
|
|
|
|
'id long, v map<int, int>',
|
|
|
|
PandasUDFType.GROUPED_MAP)
|
2017-10-10 18:32:01 -04:00
|
|
|
|
|
|
|
def test_wrong_args(self):
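# groupby().apply() accepts only a GROUPED_MAP pandas UDF; plain callables, regular
# UDFs, column expressions and SCALAR pandas UDFs are all rejected with ValueError.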
|
2017-11-17 10:43:08 -05:00
|
|
|
from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType
|
2017-10-10 18:32:01 -04:00
|
|
|
df = self.data
|
|
|
|
|
|
|
|
with QuietTest(self.sc):
|
2017-11-17 10:43:08 -05:00
|
|
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
2017-10-10 18:32:01 -04:00
|
|
|
df.groupby('id').apply(lambda x: x)
|
2017-11-17 10:43:08 -05:00
|
|
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
2017-10-10 18:32:01 -04:00
|
|
|
df.groupby('id').apply(udf(lambda x: x, DoubleType()))
|
2017-11-17 10:43:08 -05:00
|
|
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
2017-10-10 18:32:01 -04:00
|
|
|
df.groupby('id').apply(sum(df.v))
|
2017-11-17 10:43:08 -05:00
|
|
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
2017-10-10 18:32:01 -04:00
|
|
|
df.groupby('id').apply(df.v + 1)
|
2017-11-17 10:43:08 -05:00
|
|
|
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
|
2017-10-20 15:44:30 -04:00
|
|
|
df.groupby('id').apply(
|
|
|
|
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
|
2017-11-17 10:43:08 -05:00
|
|
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
2018-02-12 06:49:36 -05:00
|
|
|
df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType()))
|
2018-01-30 07:55:55 -05:00
|
|
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
|
2017-11-17 10:43:08 -05:00
|
|
|
df.groupby('id').apply(
|
2018-02-12 06:49:36 -05:00
|
|
|
pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
|
2017-10-10 18:32:01 -04:00
|
|
|
|
2017-10-20 15:44:30 -04:00
|
|
|
def test_unsupported_types(self):
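# Return schemas containing MapType or ArrayType(TimestampType()) fields are rejected
# for grouped map pandas UDFs.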
|
2018-02-12 06:49:36 -05:00
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
2017-10-20 15:44:30 -04:00
|
|
|
schema = StructType(
|
2017-12-26 07:37:25 -05:00
|
|
|
[StructField("id", LongType(), True),
|
|
|
|
StructField("map", MapType(StringType(), IntegerType()), True)])
|
2017-10-20 15:44:30 -04:00
|
|
|
with QuietTest(self.sc):
|
2018-02-12 06:49:36 -05:00
|
|
|
with self.assertRaisesRegexp(
|
|
|
|
NotImplementedError,
|
|
|
|
'Invalid returnType.*grouped map Pandas UDF.*MapType'):
|
|
|
|
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
|
|
|
|
|
|
|
|
schema = StructType(
|
|
|
|
[StructField("id", LongType(), True),
|
|
|
|
StructField("arr_ts", ArrayType(TimestampType()), True)])
|
|
|
|
with QuietTest(self.sc):
|
|
|
|
with self.assertRaisesRegexp(
|
|
|
|
NotImplementedError,
|
|
|
|
'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
|
|
|
|
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
|
2017-10-20 15:44:30 -04:00
|
|
|
|
2018-02-11 03:31:35 -05:00
|
|
|
# Regression test for SPARK-23314
|
|
|
|
def test_timestamp_dst(self):
|
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
|
|
|
# Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
|
|
|
|
dt = [datetime.datetime(2015, 11, 1, 0, 30),
|
|
|
|
datetime.datetime(2015, 11, 1, 1, 30),
|
|
|
|
datetime.datetime(2015, 11, 1, 2, 30)]
|
|
|
|
df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
|
|
|
|
foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP)
|
|
|
|
result = df.groupby('time').apply(foo_udf).sort('time')
|
|
|
|
self.assertPandasEqual(df.toPandas(), result.toPandas())
|
|
|
|
|
2018-03-08 06:29:07 -05:00
|
|
|
def test_udf_with_key(self):
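# A grouped map UDF defined with a (key, pdf) signature receives the grouping key as a
# tuple; covered below for a plain column, a grouping expression, a composite key and
# an empty groupby.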
|
|
|
|
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
|
|
|
|
df = self.data
|
|
|
|
pdf = df.toPandas()
|
|
|
|
|
|
|
|
def foo1(key, pdf):
|
|
|
|
import numpy as np
|
|
|
|
assert type(key) == tuple
|
|
|
|
assert type(key[0]) == np.int64
|
|
|
|
|
|
|
|
return pdf.assign(v1=key[0],
|
|
|
|
v2=pdf.v * key[0],
|
|
|
|
v3=pdf.v * pdf.id,
|
|
|
|
v4=pdf.v * pdf.id.mean())
|
|
|
|
|
|
|
|
def foo2(key, pdf):
|
|
|
|
import numpy as np
|
|
|
|
assert type(key) == tuple
|
|
|
|
assert type(key[0]) == np.int64
|
|
|
|
assert type(key[1]) == np.int32
|
|
|
|
|
|
|
|
return pdf.assign(v1=key[0],
|
|
|
|
v2=key[1],
|
|
|
|
v3=pdf.v * key[0],
|
|
|
|
v4=pdf.v + key[1])
|
|
|
|
|
|
|
|
def foo3(key, pdf):
|
|
|
|
assert type(key) == tuple
|
|
|
|
assert len(key) == 0
|
|
|
|
return pdf.assign(v1=pdf.v * pdf.id)
|
|
|
|
|
|
|
|
# v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
|
|
|
|
# v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
|
|
|
|
udf1 = pandas_udf(
|
|
|
|
foo1,
|
|
|
|
'id long, v int, v1 long, v2 int, v3 long, v4 double',
|
|
|
|
PandasUDFType.GROUPED_MAP)
|
|
|
|
|
|
|
|
udf2 = pandas_udf(
|
|
|
|
foo2,
|
|
|
|
'id long, v int, v1 long, v2 int, v3 int, v4 int',
|
|
|
|
PandasUDFType.GROUPED_MAP)
|
|
|
|
|
|
|
|
udf3 = pandas_udf(
|
|
|
|
foo3,
|
|
|
|
'id long, v int, v1 long',
|
|
|
|
PandasUDFType.GROUPED_MAP)
|
|
|
|
|
|
|
|
# Test groupby column
|
|
|
|
result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
|
|
|
|
expected1 = pdf.groupby('id')\
|
|
|
|
.apply(lambda x: udf1.func((x.id.iloc[0],), x))\
|
|
|
|
.sort_values(['id', 'v']).reset_index(drop=True)
|
|
|
|
self.assertPandasEqual(expected1, result1)
|
|
|
|
|
|
|
|
# Test groupby expression
|
|
|
|
result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
|
|
|
|
expected2 = pdf.groupby(pdf.id % 2)\
|
|
|
|
.apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
|
|
|
|
.sort_values(['id', 'v']).reset_index(drop=True)
|
|
|
|
self.assertPandasEqual(expected2, result2)
|
|
|
|
|
|
|
|
# Test complex groupby
|
|
|
|
result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
|
|
|
|
expected3 = pdf.groupby([pdf.id, pdf.v % 2])\
|
|
|
|
.apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
|
|
|
|
.sort_values(['id', 'v']).reset_index(drop=True)
|
|
|
|
self.assertPandasEqual(expected3, result3)
|
|
|
|
|
|
|
|
# Test empty groupby
|
|
|
|
result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
|
|
|
|
expected4 = udf3.func((), pdf)
|
|
|
|
self.assertPandasEqual(expected4, result4)
|
|
|
|
|
2017-10-10 18:32:01 -04:00
|
|
|
|
2018-02-07 09:28:10 -05:00
|
|
|
@unittest.skipIf(
|
|
|
|
not _have_pandas or not _have_pyarrow,
|
|
|
|
_pandas_requirement_message or _pyarrow_requirement_message)
|
2018-02-06 15:30:04 -05:00
|
|
|
class GroupedAggPandasUDFTests(ReusedSQLTestCase):
|
2018-01-23 00:11:30 -05:00
|
|
|
|
|
|
|
@property
|
|
|
|
def data(self):
|
|
|
|
from pyspark.sql.functions import array, explode, col, lit
|
|
|
|
return self.spark.range(10).toDF('id') \
|
|
|
|
.withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
|
|
|
|
.withColumn("v", explode(col('vs'))) \
|
|
|
|
.drop('vs') \
|
|
|
|
.withColumn('w', lit(1.0))
|
|
|
|
|
|
|
|
@property
|
|
|
|
def python_plus_one(self):
|
|
|
|
from pyspark.sql.functions import udf
|
|
|
|
|
|
|
|
@udf('double')
|
|
|
|
def plus_one(v):
|
|
|
|
assert isinstance(v, (int, float))
|
|
|
|
return v + 1
|
|
|
|
return plus_one
|
|
|
|
|
|
|
|
@property
|
|
|
|
def pandas_scalar_plus_two(self):
|
|
|
|
import pandas as pd
|
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
|
|
|
|
|
|
|
@pandas_udf('double', PandasUDFType.SCALAR)
|
|
|
|
def plus_two(v):
|
|
|
|
assert isinstance(v, pd.Series)
|
|
|
|
return v + 2
|
|
|
|
return plus_two
|
|
|
|
|
|
|
|
@property
|
|
|
|
def pandas_agg_mean_udf(self):
|
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
|
|
|
|
2018-01-30 07:55:55 -05:00
|
|
|
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
|
2018-01-23 00:11:30 -05:00
|
|
|
def avg(v):
|
|
|
|
return v.mean()
|
|
|
|
return avg
|
|
|
|
|
|
|
|
@property
|
|
|
|
def pandas_agg_sum_udf(self):
|
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
|
|
|
|
2018-01-30 07:55:55 -05:00
|
|
|
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
|
2018-01-23 00:11:30 -05:00
|
|
|
def sum(v):
|
|
|
|
return v.sum()
|
|
|
|
return sum
|
|
|
|
|
|
|
|
@property
|
|
|
|
def pandas_agg_weighted_mean_udf(self):
|
|
|
|
import numpy as np
|
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
|
|
|
|
2018-01-30 07:55:55 -05:00
|
|
|
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
|
2018-01-23 00:11:30 -05:00
|
|
|
def weighted_mean(v, w):
|
|
|
|
return np.average(v, weights=w)
|
|
|
|
return weighted_mean
|
|
|
|
|
|
|
|
def test_manual(self):
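# Compare grouped aggregate pandas UDF results against a manually constructed
# expected DataFrame.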
|
2018-02-12 06:49:36 -05:00
|
|
|
from pyspark.sql.functions import pandas_udf, array
|
|
|
|
|
2018-01-23 00:11:30 -05:00
|
|
|
df = self.data
|
|
|
|
sum_udf = self.pandas_agg_sum_udf
|
|
|
|
mean_udf = self.pandas_agg_mean_udf
|
2018-02-12 06:49:36 -05:00
|
|
|
mean_arr_udf = pandas_udf(
|
|
|
|
self.pandas_agg_mean_udf.func,
|
|
|
|
ArrayType(self.pandas_agg_mean_udf.returnType),
|
|
|
|
self.pandas_agg_mean_udf.evalType)
|
|
|
|
|
|
|
|
result1 = df.groupby('id').agg(
|
|
|
|
sum_udf(df.v),
|
|
|
|
mean_udf(df.v),
|
|
|
|
mean_arr_udf(array(df.v))).sort('id')
|
2018-01-23 00:11:30 -05:00
|
|
|
expected1 = self.spark.createDataFrame(
|
2018-02-12 06:49:36 -05:00
|
|
|
[[0, 245.0, 24.5, [24.5]],
|
|
|
|
[1, 255.0, 25.5, [25.5]],
|
|
|
|
[2, 265.0, 26.5, [26.5]],
|
|
|
|
[3, 275.0, 27.5, [27.5]],
|
|
|
|
[4, 285.0, 28.5, [28.5]],
|
|
|
|
[5, 295.0, 29.5, [29.5]],
|
|
|
|
[6, 305.0, 30.5, [30.5]],
|
|
|
|
[7, 315.0, 31.5, [31.5]],
|
|
|
|
[8, 325.0, 32.5, [32.5]],
|
|
|
|
[9, 335.0, 33.5, [33.5]]],
|
|
|
|
['id', 'sum(v)', 'avg(v)', 'avg(array(v))'])
|
2018-01-23 00:11:30 -05:00
|
|
|
|
|
|
|
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
|
|
|
|
|
|
|
def test_basic(self):
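# Group aggregate pandas UDF (weighted mean) grouped by a column or by an expression,
# with a literal weight and with a weight column.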
|
|
|
|
from pyspark.sql.functions import col, lit, sum, mean
|
|
|
|
|
|
|
|
df = self.data
|
|
|
|
weighted_mean_udf = self.pandas_agg_weighted_mean_udf
|
|
|
|
|
|
|
|
# Groupby one column and aggregate one UDF with literal
|
|
|
|
result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id')
|
|
|
|
expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
|
|
|
|
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
|
|
|
|
|
|
|
# Groupby one expression and aggregate one UDF with literal
|
|
|
|
result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\
|
|
|
|
.sort(df.id + 1)
|
|
|
|
expected2 = df.groupby((col('id') + 1))\
|
|
|
|
.agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1)
|
|
|
|
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
|
|
|
|
|
|
|
|
# Groupby one column and aggregate one UDF without literal
|
|
|
|
result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id')
|
|
|
|
expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id')
|
|
|
|
self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
|
|
|
|
|
|
|
|
# Groupby one expression and aggregate one UDF without literal
|
|
|
|
result4 = df.groupby((col('id') + 1).alias('id'))\
|
|
|
|
.agg(weighted_mean_udf(df.v, df.w))\
|
|
|
|
.sort('id')
|
|
|
|
expected4 = df.groupby((col('id') + 1).alias('id'))\
|
|
|
|
.agg(mean(df.v).alias('weighted_mean(v, w)'))\
|
|
|
|
.sort('id')
|
|
|
|
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
|
|
|
|
|
|
|
|
def test_unsupported_types(self):
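# GROUPED_AGG pandas UDFs reject unsupported return types: nested arrays of
# timestamps, struct-typed results and MapType.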
|
2018-02-12 06:49:36 -05:00
|
|
|
from pyspark.sql.types import DoubleType, MapType
|
2018-01-23 00:11:30 -05:00
|
|
|
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
|
|
|
|
|
|
|
with QuietTest(self.sc):
|
2018-02-01 01:26:27 -05:00
|
|
|
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
|
2018-02-12 06:49:36 -05:00
|
|
|
pandas_udf(
|
|
|
|
lambda x: x,
|
|
|
|
ArrayType(ArrayType(TimestampType())),
|
|
|
|
PandasUDFType.GROUPED_AGG)
|
2018-01-23 00:11:30 -05:00
|
|
|
|
|
|
|
with QuietTest(self.sc):
|
2018-02-01 01:26:27 -05:00
|
|
|
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
|
2018-01-30 07:55:55 -05:00
|
|
|
@pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG)
|
2018-01-23 00:11:30 -05:00
|
|
|
def mean_and_std_udf(v):
|
|
|
|
return v.mean(), v.std()
|
|
|
|
|
|
|
|
with QuietTest(self.sc):
|
2018-02-01 01:26:27 -05:00
|
|
|
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
|
2018-01-30 07:55:55 -05:00
|
|
|
@pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG)
|
2018-01-23 00:11:30 -05:00
|
|
|
def mean_and_std_udf(v):
|
|
|
|
return {v.mean(): v.std()}
|
|
|
|
|
|
|
|
def test_alias(self):
|
|
|
|
from pyspark.sql.functions import mean
|
|
|
|
|
|
|
|
df = self.data
|
|
|
|
mean_udf = self.pandas_agg_mean_udf
|
|
|
|
|
|
|
|
result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias'))
|
|
|
|
expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias'))
|
|
|
|
|
|
|
|
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
|
|
|
|
|
|
|
def test_mixed_sql(self):
|
|
|
|
"""
|
|
|
|
Test mixing group aggregate pandas UDF with sql expression.
|
|
|
|
"""
|
|
|
|
from pyspark.sql.functions import sum, mean
|
|
|
|
|
|
|
|
df = self.data
|
|
|
|
sum_udf = self.pandas_agg_sum_udf
|
|
|
|
|
|
|
|
# Mix group aggregate pandas UDF with sql expression
|
|
|
|
result1 = (df.groupby('id')
|
|
|
|
.agg(sum_udf(df.v) + 1)
|
|
|
|
.sort('id'))
|
|
|
|
expected1 = (df.groupby('id')
|
|
|
|
.agg(sum(df.v) + 1)
|
|
|
|
.sort('id'))
|
|
|
|
|
|
|
|
# Mix group aggregate pandas UDF with sql expression (order swapped)
|
|
|
|
result2 = (df.groupby('id')
|
|
|
|
.agg(sum_udf(df.v + 1))
|
|
|
|
.sort('id'))
|
|
|
|
|
|
|
|
expected2 = (df.groupby('id')
|
|
|
|
.agg(sum(df.v + 1))
|
|
|
|
.sort('id'))
|
|
|
|
|
|
|
|
# Wrap group aggregate pandas UDF with two sql expressions
|
|
|
|
result3 = (df.groupby('id')
|
|
|
|
.agg(sum_udf(df.v + 1) + 2)
|
|
|
|
.sort('id'))
|
|
|
|
expected3 = (df.groupby('id')
|
|
|
|
.agg(sum(df.v + 1) + 2)
|
|
|
|
.sort('id'))
|
|
|
|
|
|
|
|
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
|
|
|
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
|
|
|
|
self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
|
|
|
|
|
|
|
|
def test_mixed_udfs(self):
|
|
|
|
"""
|
|
|
|
Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF.
|
|
|
|
"""
|
|
|
|
from pyspark.sql.functions import sum, mean
|
|
|
|
|
|
|
|
df = self.data
|
|
|
|
plus_one = self.python_plus_one
|
|
|
|
plus_two = self.pandas_scalar_plus_two
|
|
|
|
sum_udf = self.pandas_agg_sum_udf
|
|
|
|
|
|
|
|
# Mix group aggregate pandas UDF and python UDF
|
|
|
|
result1 = (df.groupby('id')
|
|
|
|
.agg(plus_one(sum_udf(df.v)))
|
|
|
|
.sort('id'))
|
|
|
|
expected1 = (df.groupby('id')
|
|
|
|
.agg(plus_one(sum(df.v)))
|
|
|
|
.sort('id'))
|
|
|
|
|
|
|
|
# Mix group aggregate pandas UDF and python UDF (order swapped)
|
|
|
|
result2 = (df.groupby('id')
|
|
|
|
.agg(sum_udf(plus_one(df.v)))
|
|
|
|
.sort('id'))
|
|
|
|
expected2 = (df.groupby('id')
|
|
|
|
.agg(sum(plus_one(df.v)))
|
|
|
|
.sort('id'))
|
|
|
|
|
|
|
|
# Mix group aggregate pandas UDF and scalar pandas UDF
|
|
|
|
result3 = (df.groupby('id')
|
|
|
|
.agg(sum_udf(plus_two(df.v)))
|
|
|
|
.sort('id'))
|
|
|
|
expected3 = (df.groupby('id')
|
|
|
|
.agg(sum(plus_two(df.v)))
|
|
|
|
.sort('id'))
|
|
|
|
|
|
|
|
# Mix group aggregate pandas UDF and scalar pandas UDF (order swapped)
|
|
|
|
result4 = (df.groupby('id')
|
|
|
|
.agg(plus_two(sum_udf(df.v)))
|
|
|
|
.sort('id'))
|
|
|
|
expected4 = (df.groupby('id')
|
|
|
|
.agg(plus_two(sum(df.v)))
|
|
|
|
.sort('id'))
|
|
|
|
|
|
|
|
# Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby
|
|
|
|
result5 = (df.groupby(plus_one(df.id))
|
|
|
|
.agg(plus_one(sum_udf(plus_one(df.v))))
|
|
|
|
.sort('plus_one(id)'))
|
|
|
|
expected5 = (df.groupby(plus_one(df.id))
|
|
|
|
.agg(plus_one(sum(plus_one(df.v))))
|
|
|
|
.sort('plus_one(id)'))
|
|
|
|
|
|
|
|
# Wrap group aggregate pandas UDF with two scalar pandas UDFs and use scalar pandas UDF in
|
|
|
|
# groupby
|
|
|
|
result6 = (df.groupby(plus_two(df.id))
|
|
|
|
.agg(plus_two(sum_udf(plus_two(df.v))))
|
|
|
|
.sort('plus_two(id)'))
|
|
|
|
expected6 = (df.groupby(plus_two(df.id))
|
|
|
|
.agg(plus_two(sum(plus_two(df.v))))
|
|
|
|
.sort('plus_two(id)'))
|
|
|
|
|
|
|
|
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
|
|
|
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
|
|
|
|
self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
|
|
|
|
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
|
|
|
|
self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
|
|
|
|
self.assertPandasEqual(expected6.toPandas(), result6.toPandas())
|
|
|
|
|
|
|
|
def test_multiple_udfs(self):
|
|
|
|
"""
|
|
|
|
Test multiple group aggregate pandas UDFs in one agg function.
|
|
|
|
"""
|
|
|
|
from pyspark.sql.functions import col, lit, sum, mean
|
|
|
|
|
|
|
|
df = self.data
|
|
|
|
mean_udf = self.pandas_agg_mean_udf
|
|
|
|
sum_udf = self.pandas_agg_sum_udf
|
|
|
|
weighted_mean_udf = self.pandas_agg_weighted_mean_udf
|
|
|
|
|
|
|
|
result1 = (df.groupBy('id')
|
|
|
|
.agg(mean_udf(df.v),
|
|
|
|
sum_udf(df.v),
|
|
|
|
weighted_mean_udf(df.v, df.w))
|
|
|
|
.sort('id')
|
|
|
|
.toPandas())
|
|
|
|
expected1 = (df.groupBy('id')
|
|
|
|
.agg(mean(df.v),
|
|
|
|
sum(df.v),
|
|
|
|
mean(df.v).alias('weighted_mean(v, w)'))
|
|
|
|
.sort('id')
|
|
|
|
.toPandas())
|
|
|
|
|
|
|
|
self.assertPandasEqual(expected1, result1)
|
|
|
|
|
|
|
|
def test_complex_groupby(self):
|
|
|
|
from pyspark.sql.functions import lit, sum
|
|
|
|
|
|
|
|
df = self.data
|
|
|
|
sum_udf = self.pandas_agg_sum_udf
|
|
|
|
plus_one = self.python_plus_one
|
|
|
|
plus_two = self.pandas_scalar_plus_two
|
|
|
|
|
|
|
|
# groupby one expression
|
|
|
|
result1 = df.groupby(df.v % 2).agg(sum_udf(df.v))
|
|
|
|
expected1 = df.groupby(df.v % 2).agg(sum(df.v))
|
|
|
|
|
|
|
|
# empty groupby
|
|
|
|
result2 = df.groupby().agg(sum_udf(df.v))
|
|
|
|
expected2 = df.groupby().agg(sum(df.v))
|
|
|
|
|
|
|
|
# groupby one column and one sql expression
|
|
|
|
result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v))
|
|
|
|
expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v))
|
|
|
|
|
|
|
|
# groupby one python UDF
|
|
|
|
result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v))
|
|
|
|
expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v))
|
|
|
|
|
|
|
|
# groupby one scalar pandas UDF
|
|
|
|
result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v))
|
|
|
|
expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v))
|
|
|
|
|
|
|
|
# groupby one expression and one python UDF
|
|
|
|
result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v))
|
|
|
|
expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v))
|
|
|
|
|
|
|
|
# groupby one expression and one scalar pandas UDF
|
|
|
|
result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)')
|
|
|
|
expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)')
|
|
|
|
|
|
|
|
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
|
|
|
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
|
|
|
|
self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
|
|
|
|
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
|
|
|
|
self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
|
|
|
|
self.assertPandasEqual(expected6.toPandas(), result6.toPandas())
|
|
|
|
self.assertPandasEqual(expected7.toPandas(), result7.toPandas())
|
|
|
|
|
|
|
|
def test_complex_expressions(self):
|
|
|
|
from pyspark.sql.functions import col, sum
|
|
|
|
|
|
|
|
df = self.data
|
|
|
|
plus_one = self.python_plus_one
|
|
|
|
plus_two = self.pandas_scalar_plus_two
|
|
|
|
sum_udf = self.pandas_agg_sum_udf
|
|
|
|
|
|
|
|
# Test complex expressions with sql expression, python UDF and
|
|
|
|
# group aggregate pandas UDF
|
|
|
|
result1 = (df.withColumn('v1', plus_one(df.v))
|
|
|
|
.withColumn('v2', df.v + 2)
|
|
|
|
.groupby(df.id, df.v % 2)
|
|
|
|
.agg(sum_udf(col('v')),
|
|
|
|
sum_udf(col('v1') + 3),
|
|
|
|
sum_udf(col('v2')) + 5,
|
|
|
|
plus_one(sum_udf(col('v1'))),
|
|
|
|
sum_udf(plus_one(col('v2'))))
|
|
|
|
.sort('id')
|
|
|
|
.toPandas())
|
|
|
|
|
|
|
|
expected1 = (df.withColumn('v1', df.v + 1)
|
|
|
|
.withColumn('v2', df.v + 2)
|
|
|
|
.groupby(df.id, df.v % 2)
|
|
|
|
.agg(sum(col('v')),
|
|
|
|
sum(col('v1') + 3),
|
|
|
|
sum(col('v2')) + 5,
|
|
|
|
plus_one(sum(col('v1'))),
|
|
|
|
sum(plus_one(col('v2'))))
|
|
|
|
.sort('id')
|
|
|
|
.toPandas())
|
|
|
|
|
|
|
|
# Test complex expressions with sql expression, scalar pandas UDF and
|
|
|
|
# group aggregate pandas UDF
|
|
|
|
result2 = (df.withColumn('v1', plus_one(df.v))
|
|
|
|
.withColumn('v2', df.v + 2)
|
|
|
|
.groupby(df.id, df.v % 2)
|
|
|
|
.agg(sum_udf(col('v')),
|
|
|
|
sum_udf(col('v1') + 3),
|
|
|
|
sum_udf(col('v2')) + 5,
|
|
|
|
plus_two(sum_udf(col('v1'))),
|
|
|
|
sum_udf(plus_two(col('v2'))))
|
|
|
|
.sort('id')
|
|
|
|
.toPandas())
|
|
|
|
|
|
|
|
expected2 = (df.withColumn('v1', df.v + 1)
|
|
|
|
.withColumn('v2', df.v + 2)
|
|
|
|
.groupby(df.id, df.v % 2)
|
|
|
|
.agg(sum(col('v')),
|
|
|
|
sum(col('v1') + 3),
|
|
|
|
sum(col('v2')) + 5,
|
|
|
|
plus_two(sum(col('v1'))),
|
|
|
|
sum(plus_two(col('v2'))))
|
|
|
|
.sort('id')
|
|
|
|
.toPandas())
|
|
|
|
|
|
|
|
# Test sequential groupby aggregate
|
|
|
|
result3 = (df.groupby('id')
|
|
|
|
.agg(sum_udf(df.v).alias('v'))
|
|
|
|
.groupby('id')
|
|
|
|
.agg(sum_udf(col('v')))
|
|
|
|
.sort('id')
|
|
|
|
.toPandas())
|
|
|
|
|
|
|
|
expected3 = (df.groupby('id')
|
|
|
|
.agg(sum(df.v).alias('v'))
|
|
|
|
.groupby('id')
|
|
|
|
.agg(sum(col('v')))
|
|
|
|
.sort('id')
|
|
|
|
.toPandas())
|
|
|
|
|
|
|
|
self.assertPandasEqual(expected1, result1)
|
|
|
|
self.assertPandasEqual(expected2, result2)
|
|
|
|
self.assertPandasEqual(expected3, result3)
|
|
|
|
|
|
|
|
def test_retain_group_columns(self):
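# Group aggregate pandas UDFs honor spark.sql.retainGroupColumns=false the same way
# builtin aggregates do.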
|
|
|
|
from pyspark.sql.functions import sum, lit, col
|
|
|
|
orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None)
|
|
|
|
self.spark.conf.set("spark.sql.retainGroupColumns", False)
|
|
|
|
try:
|
|
|
|
df = self.data
|
|
|
|
sum_udf = self.pandas_agg_sum_udf
|
|
|
|
|
|
|
|
result1 = df.groupby(df.id).agg(sum_udf(df.v))
|
|
|
|
expected1 = df.groupby(df.id).agg(sum(df.v))
|
|
|
|
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
|
|
|
|
|
|
|
finally:
|
|
|
|
if orig_value is None:
|
|
|
|
self.spark.conf.unset("spark.sql.retainGroupColumns")
|
|
|
|
else:
|
|
|
|
self.spark.conf.set("spark.sql.retainGroupColumns", orig_value)
|
|
|
|
|
|
|
|
def test_invalid_args(self):
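# Invalid usages inside agg(): a non-aggregate UDF, a nested group aggregate UDF, and
# mixing a pandas aggregate UDF with a builtin aggregate all raise AnalysisException.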
|
|
|
|
from pyspark.sql.functions import mean
|
|
|
|
|
|
|
|
df = self.data
|
|
|
|
plus_one = self.python_plus_one
|
|
|
|
mean_udf = self.pandas_agg_mean_udf
|
|
|
|
|
|
|
|
with QuietTest(self.sc):
|
|
|
|
with self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
'nor.*aggregate function'):
|
|
|
|
df.groupby(df.id).agg(plus_one(df.v)).collect()
|
|
|
|
|
|
|
|
with QuietTest(self.sc):
|
|
|
|
with self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
'aggregate function.*argument.*aggregate function'):
|
|
|
|
df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()
|
|
|
|
|
|
|
|
with QuietTest(self.sc):
|
|
|
|
with self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
'mixture.*aggregate function.*group aggregate pandas UDF'):
|
|
|
|
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
if __name__ == "__main__":
|
2016-01-20 14:11:10 -05:00
|
|
|
from pyspark.sql.tests import *
|
2015-10-22 18:27:11 -04:00
|
|
|
if xmlrunner:
|
|
|
|
unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
|
|
|
|
else:
|
|
|
|
unittest.main()
|