# -*- encoding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Unit tests for pyspark.sql; additional tests are implemented as doctests in
individual modules.
"""

import os
import sys
import subprocess
import pydoc
import shutil
import tempfile
import pickle
import functools
import time
import datetime

import py4j
try:
    import xmlrunner
except ImportError:
    xmlrunner = None

if sys.version_info[:2] <= (2, 6):
    try:
        import unittest2 as unittest
    except ImportError:
        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
        sys.exit(1)
else:
    import unittest

from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException


class UTCOffsetTimezone(datetime.tzinfo):
    """
    Specifies a timezone by its UTC offset (in hours).
    """

    def __init__(self, offset=0):
        self.ZERO = datetime.timedelta(hours=offset)

    def utcoffset(self, dt):
        return self.ZERO

    def dst(self, dt):
        return self.ZERO


class ExamplePointUDT(UserDefinedType):
    """
    User-defined type (UDT) for ExamplePoint.
    """

    @classmethod
    def sqlType(self):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return 'pyspark.sql.tests'

    @classmethod
    def scalaUDT(cls):
        return 'org.apache.spark.sql.test.ExamplePointUDT'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return ExamplePoint(datum[0], datum[1])


class ExamplePoint:
    """
    An example class to demonstrate UDT in Scala, Java, and Python.
    """

    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        return isinstance(other, self.__class__) and \
            other.x == self.x and other.y == self.y


class PythonOnlyUDT(UserDefinedType):
    """
    User-defined type (UDT) for PythonOnlyPoint.
    """

    @classmethod
    def sqlType(self):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return '__main__'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return PythonOnlyPoint(datum[0], datum[1])

    @staticmethod
    def foo():
        pass

    @property
    def props(self):
        return {}


class PythonOnlyPoint(ExamplePoint):
    """
    An example class to demonstrate a UDT implemented only in Python.
    """
    __UDT__ = PythonOnlyUDT()
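

# Simple key/value holder used as test data.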
class MyObject(object):
    def __init__(self, key, value):
        self.key = key
        self.value = value


class DataTypeTests(unittest.TestCase):
    # regression test for SPARK-6055
    def test_data_type_eq(self):
        lt = LongType()
        lt2 = pickle.loads(pickle.dumps(LongType()))
        self.assertEqual(lt, lt2)

    # regression test for SPARK-7978
    def test_decimal_type(self):
        t1 = DecimalType()
        t2 = DecimalType(10, 2)
        self.assertTrue(t2 is not t1)
        self.assertNotEqual(t1, t2)
        t3 = DecimalType(8)
        self.assertNotEqual(t2, t3)

    # regression test for SPARK-10392
    def test_datetype_equal_zero(self):
        dt = DateType()
        self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))

    # regression test for SPARK-17035
    def test_timestamp_microsecond(self):
        tst = TimestampType()
        self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999)

    def test_empty_row(self):
        row = Row()
        self.assertEqual(len(row), 0)


class SQLTests(ReusedPySparkTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(cls.tempdir.name)
        cls.spark = SparkSession(cls.sc)
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        cls.df = cls.spark.createDataFrame(cls.testData)

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        cls.spark.stop()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)

    def test_sqlcontext_reuses_sparksession(self):
        sqlContext1 = SQLContext(self.sc)
        sqlContext2 = SQLContext(self.sc)
        self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)

    def test_row_should_be_read_only(self):
        row = Row(a=1, b=2)
        self.assertEqual(1, row.a)

        def foo():
            row.a = 3
        self.assertRaises(Exception, foo)

        row2 = self.spark.range(10).first()
        self.assertEqual(0, row2.id)

        def foo2():
            row2.id = 2
        self.assertRaises(Exception, foo2)

    def test_range(self):
        self.assertEqual(self.spark.range(1, 1).count(), 0)
        self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
        self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
        self.assertEqual(self.spark.range(-2).count(), 0)
        self.assertEqual(self.spark.range(3).count(), 3)

    def test_duplicated_column_names(self):
        df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
        row = df.select('*').first()
        self.assertEqual(1, row[0])
        self.assertEqual(2, row[1])
        self.assertEqual("Row(c=1, c=2)", str(row))
        # Cannot access columns
        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
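
    # regression test for SPARK-15244: column names created by createDataFrame should be str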
    def test_column_name_encoding(self):
        """Ensure that created columns have `str` type consistently."""
        columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns
        self.assertEqual(columns, ['name', 'age'])
        self.assertTrue(isinstance(columns[0], str))
        self.assertTrue(isinstance(columns[1], str))

    def test_explode(self):
        from pyspark.sql.functions import explode
        d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        result = data.select(explode(data.intlist).alias("a")).select("a").collect()
        self.assertEqual(result[0][0], 1)
        self.assertEqual(result[1][0], 2)
        self.assertEqual(result[2][0], 3)

        result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
        self.assertEqual(result[0][0], "a")
        self.assertEqual(result[0][1], "b")

    def test_and_in_expression(self):
        self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
        self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
        self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count())
        self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2")
        self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count())
        self.assertRaises(ValueError, lambda: not self.df.key == 1)

    def test_udf_with_callable(self):
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        class PlusFour:
            def __call__(self, col):
                if col is not None:
                    return col + 4

        call = PlusFour()
        pudf = UserDefinedFunction(call, LongType())
        res = data.select(pudf(data['number']).alias('plus_four'))
        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

    def test_udf_with_partial_function(self):
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        def some_func(col, param):
            if col is not None:
                return col + param

        pfunc = functools.partial(some_func, param=4)
        pudf = UserDefinedFunction(pfunc, LongType())
        res = data.select(pudf(data['number']).alias('plus_four'))
        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

    def test_udf(self):
        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
        self.assertEqual(row[0], 5)

    def test_udf2(self):
        self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
        self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
            .createOrReplaceTempView("test")
        [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
        self.assertEqual(4, res[0])
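
    # regression test for SPARK-14215: chained Python UDFs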
    def test_chained_udf(self):
        self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
        [row] = self.spark.sql("SELECT double(1)").collect()
        self.assertEqual(row[0], 2)
        [row] = self.spark.sql("SELECT double(double(1))").collect()
        self.assertEqual(row[0], 4)
        [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
        self.assertEqual(row[0], 6)
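
    # regression test for SPARK-14267: multiple Python UDFs executed within a single batch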
    def test_multiple_udfs(self):
        self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
        [row] = self.spark.sql("SELECT double(1), double(2)").collect()
        self.assertEqual(tuple(row), (2, 4))
        [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
        self.assertEqual(tuple(row), (4, 12))
        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
        [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
        self.assertEqual(tuple(row), (6, 5))

    def test_udf_in_filter_on_top_of_outer_join(self):
        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(a=1)])
        df = left.join(right, on='a', how='left_outer')
        df = df.withColumn('b', udf(lambda x: 'x')(df.a))
        self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])

    def test_udf_in_filter_on_top_of_join(self):
        # regression test for SPARK-18589
        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a, b: a == b, BooleanType())
        df = left.crossJoin(right).filter(f("a", "b"))
        self.assertEqual(df.collect(), [Row(a=1, b=1)])

    def test_udf_without_arguments(self):
        self.spark.catalog.registerFunction("foo", lambda: "bar")
        [row] = self.spark.sql("SELECT foo()").collect()
        self.assertEqual(row[0], "bar")

    def test_udf_with_array_type(self):
        d = [Row(l=list(range(3)), d={"key": list(range(5))})]
        rdd = self.sc.parallelize(d)
        self.spark.createDataFrame(rdd).createOrReplaceTempView("test")
        self.spark.catalog.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
        self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType())
        [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect()
        self.assertEqual(list(range(3)), l1)
        self.assertEqual(1, l2)

    def test_broadcast_in_udf(self):
        bar = {"a": "aa", "b": "bb", "c": "abc"}
        foo = self.sc.broadcast(bar)
        self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
        [res] = self.spark.sql("SELECT MYUDF('c')").collect()
        self.assertEqual("abc", res[0])
        [res] = self.spark.sql("SELECT MYUDF('')").collect()
        self.assertEqual("", res[0])
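
    # regression test for SPARK-18766: push down deterministic filters through BatchEvalPython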
    def test_udf_with_filter_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql.functions import udf, col
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a < 2, BooleanType())
        sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
        self.assertEqual(sel.collect(), [Row(key=1, value='1')])
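
    # regression test for SPARK-15888: Python UDF on top of an aggregate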
    def test_udf_with_aggregate_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql.functions import udf, col, sum
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a == 1, BooleanType())
        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
        self.assertEqual(sel.collect(), [Row(key=1)])

        my_copy = udf(lambda x: x, IntegerType())
        my_add = udf(lambda a, b: int(a + b), IntegerType())
        my_strlen = udf(lambda x: len(x), IntegerType())
        sel = df.groupBy(my_copy(col("key")).alias("k"))\
            .agg(sum(my_strlen(col("value"))).alias("s"))\
            .select(my_add(col("k"), col("s")).alias("t"))
        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
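
    # regression test for SPARK-16179: Python UDF used inside explode()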
    def test_udf_in_generate(self):
        from pyspark.sql.functions import udf, explode
        df = self.spark.range(5)
        f = udf(lambda x: list(range(x)), ArrayType(LongType()))
        row = df.select(explode(f(*df))).groupBy().sum().first()
        self.assertEqual(row[0], 10)
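
        # regression test for SPARK-18634: correctness issues when exploding Python UDFs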
        df = self.spark.range(3)
        res = df.select("id", explode(f(df.id))).collect()
        self.assertEqual(res[0][0], 1)
        self.assertEqual(res[0][1], 0)
        self.assertEqual(res[1][0], 2)
        self.assertEqual(res[1][1], 0)
        self.assertEqual(res[2][0], 2)
        self.assertEqual(res[2][1], 1)

        range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
        res = df.select("id", explode(range_udf(df.id))).collect()
        self.assertEqual(res[0][0], 0)
        self.assertEqual(res[0][1], -1)
        self.assertEqual(res[1][0], 0)
        self.assertEqual(res[1][1], 0)
        self.assertEqual(res[2][0], 1)
        self.assertEqual(res[2][1], 0)
        self.assertEqual(res[3][0], 1)
        self.assertEqual(res[3][1], 1)

    def test_udf_with_order_by_and_limit(self):
        from pyspark.sql.functions import udf
        my_copy = udf(lambda x: x, IntegerType())
        df = self.spark.range(10).orderBy("id")
        res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
        res.explain(True)
        self.assertEqual(res.collect(), [Row(id=0, copy=0)])

    def test_wholefile_json(self):
        people1 = self.spark.read.json("python/test_support/sql/people.json")
        people_array = self.spark.read.json("python/test_support/sql/people_array.json",
                                            wholeFile=True)
        self.assertEqual(people1.collect(), people_array.collect())

    def test_wholefile_csv(self):
        ages_newlines = self.spark.read.csv(
            "python/test_support/sql/ages_newlines.csv", wholeFile=True)
        expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'),
                    Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'),
                    Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
        self.assertEqual(ages_newlines.collect(), expected)
[SPARK-18579][SQL] Use ignoreLeadingWhiteSpace and ignoreTrailingWhiteSpace options in CSV writing
## What changes were proposed in this pull request?
This PR proposes to support _not_ trimming the white spaces when writing out. These options are `false` by default in the CSV reading path but `true` by default in the CSV writing path in the univocity parser.
Both the `ignoreLeadingWhiteSpace` and `ignoreTrailingWhiteSpace` options are not being used for writing, and therefore we are always trimming the white spaces.
It seems we should provide a way to keep these white spaces easily.
With the data below:
```scala
val df = spark.read.csv(Seq("a , b , c").toDS)
df.show()
```
```
+---+----+---+
|_c0| _c1|_c2|
+---+----+---+
| a | b | c|
+---+----+---+
```
**Before**
```scala
df.write.csv("/tmp/text.csv")
spark.read.text("/tmp/text.csv").show()
```
```
+-----+
|value|
+-----+
|a,b,c|
+-----+
```
It seems this can't be worked around via `quoteAll` either.
```scala
df.write.option("quoteAll", true).csv("/tmp/text.csv")
spark.read.text("/tmp/text.csv").show()
```
```
+-----------+
| value|
+-----------+
|"a","b","c"|
+-----------+
```
**After**
```scala
df.write.option("ignoreLeadingWhiteSpace", false).option("ignoreTrailingWhiteSpace", false).csv("/tmp/text.csv")
spark.read.text("/tmp/text.csv").show()
```
```
+----------+
| value|
+----------+
|a , b , c|
+----------+
```
Note that this case is possible in R:
```r
> system("cat text.csv")
f1,f2,f3
a , b , c
> df <- read.csv(file="text.csv")
> df
f1 f2 f3
1 a b c
> write.csv(df, file="text1.csv", quote=F, row.names=F)
> system("cat text1.csv")
f1,f2,f3
a , b , c
```
## How was this patch tested?
Unit tests in `CSVSuite` and manual tests for Python.
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #17310 from HyukjinKwon/SPARK-18579.
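As a hedged Python-side sketch mirroring the Scala example above (the output path and DataFrame contents are illustrative assumptions, and an existing `SparkSession` named `spark` is assumed):
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a ", " b", " c ")])

# Keep leading/trailing white spaces instead of the default trimming on write.
(df.write
   .option("ignoreLeadingWhiteSpace", False)
   .option("ignoreTrailingWhiteSpace", False)
   .mode("overwrite")
   .csv("/tmp/text.csv"))

# Reading the file back as plain text shows the spaces were preserved.
spark.read.text("/tmp/text.csv").show()
```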
2017-03-23 03:25:01 -04:00
|
|
|
def test_ignorewhitespace_csv(self):
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv(
|
|
|
|
tmpPath,
|
|
|
|
ignoreLeadingWhiteSpace=False,
|
|
|
|
ignoreTrailingWhiteSpace=False)
|
|
|
|
|
|
|
|
expected = [Row(value=u' a,b , c ')]
|
|
|
|
readback = self.spark.read.text(tmpPath)
|
|
|
|
self.assertEqual(readback.collect(), expected)
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
2017-03-09 14:44:34 -05:00
|
|
|
def test_read_multiple_orc_file(self):
|
|
|
|
df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0",
|
|
|
|
"python/test_support/sql/orc_partitioned/b=1/c=1"])
|
|
|
|
self.assertEqual(2, df.count())
|
|
|
|
|
2016-12-08 10:22:18 -05:00
|
|
|
def test_udf_with_input_file_name(self):
|
|
|
|
from pyspark.sql.functions import udf, input_file_name
|
|
|
|
from pyspark.sql.types import StringType
|
|
|
|
sourceFile = udf(lambda path: path, StringType())
|
|
|
|
filePath = "python/test_support/sql/people1.json"
|
|
|
|
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
|
|
|
|
self.assertTrue(row[0].find("people1.json") != -1)
|
|
|
|
|
[SPARK-19223][SQL][PYSPARK] Fix InputFileBlockHolder for datasources which are based on HadoopRDD or NewHadoopRDD
## What changes were proposed in this pull request?
For some datasources that are based on HadoopRDD or NewHadoopRDD, such as spark-xml, InputFileBlockHolder doesn't work with Python UDFs.
To reproduce it, run the following code with `bin/pyspark --packages com.databricks:spark-xml_2.11:0.4.1`:
from pyspark.sql.functions import udf,input_file_name
from pyspark.sql.types import StringType
from pyspark.sql import SparkSession
def filename(path):
return path
session = SparkSession.builder.appName('APP').getOrCreate()
session.udf.register('sameText', filename)
sameText = udf(filename, StringType())
df = session.read.format('xml').load('a.xml', rowTag='root').select('*', input_file_name().alias('file'))
df.select('file').show() # works
df.select(sameText(df['file'])).show() # returns empty content
The issue is that in `HadoopRDD` and `NewHadoopRDD` we set the file block's info in `InputFileBlockHolder` before the returned iterator begins consuming. `InputFileBlockHolder` records this info in a thread-local variable. When running a Python UDF in batch, we set up another thread to consume the iterator from the child plan's output RDD, so we can't read the info back from that other thread.
To fix this, we have to set the info in `InputFileBlockHolder` after the iterator begins consuming, so the info can be read in the correct thread.
## How was this patch tested?
Manual test with above example codes for spark-xml package on pyspark: `bin/pyspark --packages com.databricks:spark-xml_2.11:0.4.1`.
Added pyspark test.
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #16585 from viirya/fix-inputfileblock-hadooprdd.
2017-01-18 10:06:44 -05:00
|
|
|
def test_udf_with_input_file_name_for_hadooprdd(self):
|
|
|
|
from pyspark.sql.functions import udf, input_file_name
|
|
|
|
from pyspark.sql.types import StringType
|
|
|
|
|
|
|
|
def filename(path):
|
|
|
|
return path
|
|
|
|
|
|
|
|
sameText = udf(filename, StringType())
|
|
|
|
|
|
|
|
rdd = self.sc.textFile('python/test_support/sql/people.json')
|
|
|
|
df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
|
|
|
|
row = df.select(sameText(df['file'])).first()
|
|
|
|
self.assertTrue(row[0].find("people.json") != -1)
|
|
|
|
|
|
|
|
rdd2 = self.sc.newAPIHadoopFile(
|
|
|
|
'python/test_support/sql/people.json',
|
|
|
|
'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
|
|
|
|
'org.apache.hadoop.io.LongWritable',
|
|
|
|
'org.apache.hadoop.io.Text')
|
|
|
|
|
|
|
|
df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
|
|
|
|
row2 = df2.select(sameText(df2['file'])).first()
|
|
|
|
self.assertTrue(row2[0].find("people.json") != -1)
|
|
|
|
|
2017-01-31 21:03:39 -05:00
|
|
|
def test_udf_defers_judf_initalization(self):
|
|
|
|
# This is separate from UDFInitializationTests
|
|
|
|
# to avoid context initialization
|
|
|
|
# when udf is called
|
|
|
|
|
|
|
|
from pyspark.sql.functions import UserDefinedFunction
|
|
|
|
|
|
|
|
f = UserDefinedFunction(lambda x: x, StringType())
|
|
|
|
|
|
|
|
self.assertIsNone(
|
|
|
|
f._judf_placeholder,
|
|
|
|
"judf should not be initialized before the first call."
|
|
|
|
)
|
|
|
|
|
|
|
|
self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")
|
|
|
|
|
|
|
|
self.assertIsNotNone(
|
|
|
|
f._judf_placeholder,
|
|
|
|
"judf should be initialized after UDF has been called."
|
|
|
|
)
|
|
|
|
|
2017-02-13 13:37:34 -05:00
|
|
|
def test_udf_with_string_return_type(self):
|
|
|
|
from pyspark.sql.functions import UserDefinedFunction
|
|
|
|
|
|
|
|
add_one = UserDefinedFunction(lambda x: x + 1, "integer")
|
|
|
|
make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
|
|
|
|
make_array = UserDefinedFunction(
|
|
|
|
lambda x: [float(x) for x in range(x, x + 3)], "array<double>")
|
|
|
|
|
|
|
|
expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
|
|
|
|
actual = (self.spark.range(1, 2).toDF("x")
|
|
|
|
.select(add_one("x"), make_pair("x"), make_array("x"))
|
|
|
|
.first())
|
|
|
|
|
|
|
|
self.assertTupleEqual(expected, actual)
|
|
|
|
|
2017-02-14 12:46:22 -05:00
|
|
|
def test_udf_shouldnt_accept_noncallable_object(self):
|
|
|
|
from pyspark.sql.functions import UserDefinedFunction
|
|
|
|
from pyspark.sql.types import StringType
|
|
|
|
|
|
|
|
non_callable = None
|
|
|
|
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
|
|
|
|
|
2017-02-15 13:16:34 -05:00
|
|
|
def test_udf_with_decorator(self):
|
|
|
|
from pyspark.sql.functions import lit, udf
|
|
|
|
from pyspark.sql.types import IntegerType, DoubleType
|
|
|
|
|
|
|
|
@udf(IntegerType())
|
|
|
|
def add_one(x):
|
|
|
|
if x is not None:
|
|
|
|
return x + 1
|
|
|
|
|
|
|
|
@udf(returnType=DoubleType())
|
|
|
|
def add_two(x):
|
|
|
|
if x is not None:
|
|
|
|
return float(x + 2)
|
|
|
|
|
|
|
|
@udf
|
|
|
|
def to_upper(x):
|
|
|
|
if x is not None:
|
|
|
|
return x.upper()
|
|
|
|
|
|
|
|
@udf()
|
|
|
|
def to_lower(x):
|
|
|
|
if x is not None:
|
|
|
|
return x.lower()
|
|
|
|
|
|
|
|
@udf
|
|
|
|
def substr(x, start, end):
|
|
|
|
if x is not None:
|
|
|
|
return x[start:end]
|
|
|
|
|
|
|
|
@udf("long")
|
|
|
|
def trunc(x):
|
|
|
|
return int(x)
|
|
|
|
|
|
|
|
@udf(returnType="double")
|
|
|
|
def as_double(x):
|
|
|
|
return float(x)
|
|
|
|
|
|
|
|
df = (
|
|
|
|
self.spark
|
|
|
|
.createDataFrame(
|
|
|
|
[(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
|
|
|
|
.select(
|
|
|
|
add_one("one"), add_two("one"),
|
|
|
|
to_upper("Foo"), to_lower("Foo"),
|
|
|
|
substr("foobar", lit(0), lit(3)),
|
|
|
|
trunc("float"), as_double("one")))
|
|
|
|
|
|
|
|
self.assertListEqual(
|
|
|
|
[tpe for _, tpe in df.dtypes],
|
|
|
|
["int", "double", "string", "string", "string", "bigint", "double"]
|
|
|
|
)
|
|
|
|
|
|
|
|
self.assertListEqual(
|
|
|
|
list(df.first()),
|
|
|
|
[2, 3.0, "FOO", "foo", "foo", 3, 1.0]
|
|
|
|
)
|
|
|
|
|
2017-02-24 11:22:30 -05:00
|
|
|
def test_udf_wrapper(self):
|
|
|
|
from pyspark.sql.functions import udf
|
|
|
|
from pyspark.sql.types import IntegerType
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
"""Identity"""
|
|
|
|
return x
|
|
|
|
|
|
|
|
return_type = IntegerType()
|
|
|
|
f_ = udf(f, return_type)
|
|
|
|
|
|
|
|
self.assertTrue(f.__doc__ in f_.__doc__)
|
|
|
|
self.assertEqual(f, f_.func)
|
|
|
|
self.assertEqual(return_type, f_.returnType)
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_basic_functions(self):
|
|
|
|
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.read.json(rdd)
|
2015-02-03 19:01:56 -05:00
|
|
|
df.count()
|
|
|
|
df.collect()
|
2015-02-14 02:03:22 -05:00
|
|
|
df.schema
|
2015-02-03 19:01:56 -05:00
|
|
|
|
|
|
|
# cache and checkpoint
|
|
|
|
self.assertFalse(df.is_cached)
|
|
|
|
df.persist()
|
2016-04-19 20:29:28 -04:00
|
|
|
df.unpersist(True)
|
2015-02-03 19:01:56 -05:00
|
|
|
df.cache()
|
|
|
|
self.assertTrue(df.is_cached)
|
|
|
|
self.assertEqual(2, df.count())
|
|
|
|
|
2016-05-17 21:01:59 -04:00
|
|
|
df.createOrReplaceTempView("temp")
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.sql("select foo from temp")
|
2015-02-03 19:01:56 -05:00
|
|
|
df.count()
|
|
|
|
df.collect()
|
|
|
|
|
|
|
|
def test_apply_schema_to_row(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
|
|
|
|
df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(df.collect(), df2.collect())
|
|
|
|
|
|
|
|
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
|
2016-05-11 14:24:16 -04:00
|
|
|
df3 = self.spark.createDataFrame(rdd, df.schema)
|
2015-12-30 14:14:47 -05:00
|
|
|
self.assertEqual(10, df3.count())
|
|
|
|
|
|
|
|
def test_infer_schema_to_local(self):
|
|
|
|
input = [{"a": 1}, {"b": "coffee"}]
|
|
|
|
rdd = self.sc.parallelize(input)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(input)
|
|
|
|
df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
|
2015-12-30 14:14:47 -05:00
|
|
|
self.assertEqual(df.schema, df2.schema)
|
|
|
|
|
2016-01-03 20:04:35 -05:00
|
|
|
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
|
2016-05-11 14:24:16 -04:00
|
|
|
df3 = self.spark.createDataFrame(rdd, df.schema)
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(10, df3.count())
|
|
|
|
|
2016-08-15 15:41:27 -04:00
|
|
|
def test_apply_schema_to_dict_and_rows(self):
|
|
|
|
schema = StructType().add("b", StringType()).add("a", IntegerType())
|
|
|
|
input = [{"a": 1}, {"b": "coffee"}]
|
|
|
|
rdd = self.sc.parallelize(input)
|
|
|
|
for verify in [False, True]:
|
|
|
|
df = self.spark.createDataFrame(input, schema, verifySchema=verify)
|
|
|
|
df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
|
|
|
|
self.assertEqual(df.schema, df2.schema)
|
|
|
|
|
|
|
|
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
|
|
|
|
df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
|
|
|
|
self.assertEqual(10, df3.count())
|
|
|
|
input = [Row(a=x, b=str(x)) for x in range(10)]
|
|
|
|
df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
|
|
|
|
self.assertEqual(10, df4.count())
|
|
|
|
|
2016-01-24 22:40:34 -05:00
|
|
|
def test_create_dataframe_schema_mismatch(self):
|
|
|
|
input = [Row(a=1)]
|
|
|
|
rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
|
|
|
|
schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(rdd, schema)
|
2016-03-08 17:00:03 -05:00
|
|
|
self.assertRaises(Exception, lambda: df.show())
|
2016-01-24 22:40:34 -05:00
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_serialize_nested_array_and_map(self):
|
|
|
|
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
|
|
|
|
rdd = self.sc.parallelize(d)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(rdd)
|
2015-02-03 19:01:56 -05:00
|
|
|
row = df.head()
|
|
|
|
self.assertEqual(1, len(row.l))
|
|
|
|
self.assertEqual(1, row.l[0].a)
|
|
|
|
self.assertEqual("2", row.d["key"].d)
|
|
|
|
|
2016-03-02 18:26:34 -05:00
|
|
|
l = df.rdd.map(lambda x: x.l).first()
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, len(l))
|
|
|
|
self.assertEqual('s', l[0].b)
|
|
|
|
|
2016-03-02 18:26:34 -05:00
|
|
|
d = df.rdd.map(lambda x: x.d).first()
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, len(d))
|
|
|
|
self.assertEqual(1.0, d["key"].c)
|
|
|
|
|
2016-03-02 18:26:34 -05:00
|
|
|
row = df.rdd.map(lambda x: x.d["key"]).first()
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1.0, row.c)
|
|
|
|
self.assertEqual("2", row.d)
|
|
|
|
|
|
|
|
def test_infer_schema(self):
|
2015-02-20 18:35:05 -05:00
|
|
|
d = [Row(l=[], d={}, s=None),
|
2015-02-03 19:01:56 -05:00
|
|
|
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
|
|
|
|
rdd = self.sc.parallelize(d)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(rdd)
|
2016-03-02 18:26:34 -05:00
|
|
|
self.assertEqual([], df.rdd.map(lambda r: r.l).first())
|
|
|
|
self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
|
2016-05-17 21:01:59 -04:00
|
|
|
df.createOrReplaceTempView("test")
|
2016-05-11 14:24:16 -04:00
|
|
|
result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'")
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, result.head()[0])
|
|
|
|
|
2016-05-11 14:24:16 -04:00
|
|
|
df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
|
2015-02-14 02:03:22 -05:00
|
|
|
self.assertEqual(df.schema, df2.schema)
|
2016-03-02 18:26:34 -05:00
|
|
|
self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
|
|
|
|
self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
|
2016-05-17 21:01:59 -04:00
|
|
|
df2.createOrReplaceTempView("test2")
|
2016-05-11 14:24:16 -04:00
|
|
|
result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, result.head()[0])
|
|
|
|
|
2015-02-24 23:51:55 -05:00
|
|
|
def test_infer_nested_schema(self):
|
|
|
|
NestedRow = Row("f1", "f2")
|
|
|
|
nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
|
|
|
|
NestedRow([2, 3], {"row2": 2.0})])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(nestedRdd1)
|
2015-02-24 23:51:55 -05:00
|
|
|
self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
|
|
|
|
|
|
|
|
nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
|
|
|
|
NestedRow([[2, 3], [3, 4]], [2, 3])])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(nestedRdd2)
|
2015-02-24 23:51:55 -05:00
|
|
|
self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
|
|
|
|
|
|
|
|
from collections import namedtuple
|
|
|
|
CustomRow = namedtuple('CustomRow', 'field1 field2')
|
|
|
|
rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
|
|
|
|
CustomRow(field1=2, field2="row2"),
|
|
|
|
CustomRow(field1=3, field2="row3")])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(rdd)
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
|
2015-02-24 23:51:55 -05:00
|
|
|
|
2015-08-26 19:04:44 -04:00
|
|
|
def test_create_dataframe_from_objects(self):
|
|
|
|
data = [MyObject(1, "1"), MyObject(2, "2")]
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(data)
|
2015-08-26 19:04:44 -04:00
|
|
|
self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
|
|
|
|
self.assertEqual(df.first(), Row(key=1, value="1"))
|
|
|
|
|
2015-07-20 15:00:48 -04:00
|
|
|
def test_select_null_literal(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.sql("select null as col")
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(Row(col=None), df.first())
|
2015-07-20 15:00:48 -04:00
|
|
|
|
2015-02-24 23:51:55 -05:00
|
|
|
def test_apply_schema(self):
|
|
|
|
from datetime import date, datetime
|
2015-04-16 19:20:57 -04:00
|
|
|
rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
|
2015-02-24 23:51:55 -05:00
|
|
|
date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
|
|
|
|
{"a": 1}, (2,), [1, 2, 3], None)])
|
|
|
|
schema = StructType([
|
|
|
|
StructField("byte1", ByteType(), False),
|
|
|
|
StructField("byte2", ByteType(), False),
|
|
|
|
StructField("short1", ShortType(), False),
|
|
|
|
StructField("short2", ShortType(), False),
|
|
|
|
StructField("int1", IntegerType(), False),
|
|
|
|
StructField("float1", FloatType(), False),
|
|
|
|
StructField("date1", DateType(), False),
|
|
|
|
StructField("time1", TimestampType(), False),
|
|
|
|
StructField("map1", MapType(StringType(), IntegerType(), False), False),
|
|
|
|
StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
|
|
|
|
StructField("list1", ArrayType(ByteType(), False), False),
|
|
|
|
StructField("null1", DoubleType(), True)])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(rdd, schema)
|
2016-03-02 18:26:34 -05:00
|
|
|
results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
|
|
|
|
x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
|
2015-02-24 23:51:55 -05:00
|
|
|
r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
|
|
|
|
datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
|
|
|
|
self.assertEqual(r, results.first())
|
|
|
|
|
2016-05-17 21:01:59 -04:00
|
|
|
df.createOrReplaceTempView("table2")
|
2016-05-11 14:24:16 -04:00
|
|
|
r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
|
|
|
|
"short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
|
|
|
|
"float1 + 1.5 as float1 FROM table2").first()
|
2015-02-24 23:51:55 -05:00
|
|
|
|
|
|
|
self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
|
|
|
|
|
|
|
|
from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
|
|
|
|
rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
|
|
|
|
{"a": 1}, (2,), [1, 2, 3])])
|
|
|
|
abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
|
|
|
|
schema = _parse_schema_abstract(abstract)
|
|
|
|
typedSchema = _infer_schema_type(rdd.first(), schema)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame(rdd, typedSchema)
|
2015-02-24 23:51:55 -05:00
|
|
|
r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
|
|
|
|
self.assertEqual(r, tuple(df.first()))
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_struct_in_map(self):
|
|
|
|
d = [Row(m={Row(i=1): Row(s="")})]
|
2015-02-14 02:03:22 -05:00
|
|
|
df = self.sc.parallelize(d).toDF()
|
2015-04-16 19:20:57 -04:00
|
|
|
k, v = list(df.head().m.items())[0]
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, k.i)
|
|
|
|
self.assertEqual("", v.s)
|
|
|
|
|
|
|
|
def test_convert_row_to_dict(self):
|
|
|
|
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
|
|
|
|
self.assertEqual(1, row.asDict()['l'][0].a)
|
2015-02-14 02:03:22 -05:00
|
|
|
df = self.sc.parallelize([row]).toDF()
|
2016-05-17 21:01:59 -04:00
|
|
|
df.createOrReplaceTempView("test")
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.sql("select l, d from test").head()
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, row.asDict()["l"][0].a)
|
|
|
|
self.assertEqual(1.0, row.asDict()['d']['key'].c)
|
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
def test_udt(self):
|
|
|
|
from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
|
|
|
|
from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
|
|
|
|
|
|
|
|
def check_datatype(datatype):
|
|
|
|
pickled = pickle.loads(pickle.dumps(datatype))
|
|
|
|
assert datatype == pickled
|
2016-08-25 02:36:04 -04:00
|
|
|
scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
|
2015-07-30 01:30:49 -04:00
|
|
|
python_datatype = _parse_datatype_json_string(scala_datatype.json())
|
|
|
|
assert datatype == python_datatype
|
|
|
|
|
|
|
|
check_datatype(ExamplePointUDT())
|
|
|
|
structtype_with_udt = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", ExamplePointUDT(), False)])
|
|
|
|
check_datatype(structtype_with_udt)
|
|
|
|
p = ExamplePoint(1.0, 2.0)
|
|
|
|
self.assertEqual(_infer_type(p), ExamplePointUDT())
|
|
|
|
_verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
|
|
|
|
self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
|
|
|
|
|
|
|
|
check_datatype(PythonOnlyUDT())
|
|
|
|
structtype_with_udt = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", PythonOnlyUDT(), False)])
|
|
|
|
check_datatype(structtype_with_udt)
|
|
|
|
p = PythonOnlyPoint(1.0, 2.0)
|
|
|
|
self.assertEqual(_infer_type(p), PythonOnlyUDT())
|
|
|
|
_verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
|
|
|
|
self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
|
|
|
|
|
[SPARK-16062] [SPARK-15989] [SQL] Fix two bugs of Python-only UDTs
## What changes were proposed in this pull request?
There are two related bugs of Python-only UDTs. Because the test case for the second one needs the first fix too, I put them into one PR. If that is not appropriate, please let me know.
### First bug: When MapObjects works on Python-only UDTs
`RowEncoder` will use `PythonUserDefinedType.sqlType` for its deserializer expression. If the SQL type is `ArrayType`, we will have `MapObjects` working on it. But `MapObjects` doesn't consider `PythonUserDefinedType` as its input data type, which causes an error like:
import pyspark.sql.group
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql.types import *
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
df = spark.createDataFrame([(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema)
df.show()
File "/home/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 312, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o36.showString.
: java.lang.RuntimeException: Error while decoding: scala.MatchError: org.apache.spark.sql.types.PythonUserDefinedTypef4ceede8 (of class org.apache.spark.sql.types.PythonUserDefinedType)
...
### Second bug: When Python-only UDTs is the element type of ArrayType
import pyspark.sql.group
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql.types import *
schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
df = spark.createDataFrame([(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], schema=schema)
df.show()
## How was this patch tested?
PySpark's sql tests.
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Closes #13778 from viirya/fix-pyudt.
2016-08-02 13:08:18 -04:00
|
|
|
def test_simple_udt_in_df(self):
|
|
|
|
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
|
|
|
|
df = self.spark.createDataFrame(
|
|
|
|
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
|
|
|
|
schema=schema)
|
|
|
|
df.show()
|
|
|
|
|
|
|
|
def test_nested_udt_in_df(self):
|
|
|
|
schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
|
|
|
|
df = self.spark.createDataFrame(
|
|
|
|
[(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
|
|
|
|
schema=schema)
|
|
|
|
df.collect()
|
|
|
|
|
|
|
|
schema = StructType().add("key", LongType()).add("val",
|
|
|
|
MapType(LongType(), PythonOnlyUDT()))
|
|
|
|
df = self.spark.createDataFrame(
|
|
|
|
[(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
|
|
|
|
schema=schema)
|
|
|
|
df.collect()
|
|
|
|
|
|
|
|
def test_complex_nested_udt_in_df(self):
|
|
|
|
from pyspark.sql.functions import udf
|
|
|
|
|
|
|
|
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
|
|
|
|
df = self.spark.createDataFrame(
|
|
|
|
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
|
|
|
|
schema=schema)
|
|
|
|
df.collect()
|
|
|
|
|
|
|
|
gd = df.groupby("key").agg({"val": "collect_list"})
|
|
|
|
gd.collect()
|
|
|
|
udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
|
|
|
|
gd.select(udf(*gd)).collect()
|
|
|
|
|
2016-06-28 17:09:38 -04:00
|
|
|
def test_udt_with_none(self):
|
|
|
|
df = self.spark.range(0, 10, 1, 1)
|
|
|
|
|
|
|
|
def myudf(x):
|
|
|
|
if x > 0:
|
|
|
|
return PythonOnlyPoint(float(x), float(x))
|
|
|
|
|
|
|
|
self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
|
|
|
|
rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
|
|
|
|
self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_infer_schema_with_udt(self):
|
2015-02-09 23:49:22 -05:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-02-03 19:01:56 -05:00
|
|
|
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-02-14 02:03:22 -05:00
|
|
|
schema = df.schema
|
2015-02-03 19:01:56 -05:00
|
|
|
field = [f for f in schema.fields if f.name == "point"][0]
|
|
|
|
self.assertEqual(type(field.dataType), ExamplePointUDT)
|
2016-05-17 21:01:59 -04:00
|
|
|
df.createOrReplaceTempView("labeled_point")
|
2016-05-11 14:24:16 -04:00
|
|
|
point = self.spark.sql("SELECT point FROM labeled_point").head().point
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(point, ExamplePoint(1.0, 2.0))
|
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-07-30 01:30:49 -04:00
|
|
|
schema = df.schema
|
|
|
|
field = [f for f in schema.fields if f.name == "point"][0]
|
|
|
|
self.assertEqual(type(field.dataType), PythonOnlyUDT)
|
2016-05-17 21:01:59 -04:00
|
|
|
df.createOrReplaceTempView("labeled_point")
|
2016-05-11 14:24:16 -04:00
|
|
|
point = self.spark.sql("SELECT point FROM labeled_point").head().point
|
2015-07-30 01:30:49 -04:00
|
|
|
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_apply_schema_with_udt(self):
|
2015-02-09 23:49:22 -05:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-02-03 19:01:56 -05:00
|
|
|
row = (1.0, ExamplePoint(1.0, 2.0))
|
|
|
|
schema = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", ExamplePointUDT(), False)])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row], schema)
|
2015-02-03 19:01:56 -05:00
|
|
|
point = df.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, ExamplePoint(1.0, 2.0))
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = (1.0, PythonOnlyPoint(1.0, 2.0))
|
|
|
|
schema = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", PythonOnlyUDT(), False)])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row], schema)
|
2015-07-30 01:30:49 -04:00
|
|
|
point = df.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
|
2015-07-30 01:30:49 -04:00
|
|
|
|
2015-07-09 17:43:38 -04:00
|
|
|
def test_udf_with_udt(self):
|
2015-07-20 15:14:47 -04:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-07-09 17:43:38 -04:00
|
|
|
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2016-03-02 18:26:34 -05:00
|
|
|
self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
|
2015-07-09 17:43:38 -04:00
|
|
|
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
|
|
|
|
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
|
2015-07-20 15:14:47 -04:00
|
|
|
udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
|
|
|
|
self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
|
2015-07-09 17:43:38 -04:00
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2016-03-02 18:26:34 -05:00
|
|
|
self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
|
2015-07-30 01:30:49 -04:00
|
|
|
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
|
|
|
|
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
|
|
|
|
udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
|
|
|
|
self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_parquet_with_udt(self):
|
2015-07-30 01:30:49 -04:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-02-03 19:01:56 -05:00
|
|
|
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df0 = self.spark.createDataFrame([row])
|
2015-02-03 19:01:56 -05:00
|
|
|
output_dir = os.path.join(self.tempdir.name, "labeled_point")
|
2015-07-30 01:30:49 -04:00
|
|
|
df0.write.parquet(output_dir)
|
2016-05-11 14:24:16 -04:00
|
|
|
df1 = self.spark.read.parquet(output_dir)
|
2015-02-03 19:01:56 -05:00
|
|
|
point = df1.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, ExamplePoint(1.0, 2.0))
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
|
2016-05-11 14:24:16 -04:00
|
|
|
df0 = self.spark.createDataFrame([row])
|
2015-07-30 01:30:49 -04:00
|
|
|
df0.write.parquet(output_dir, mode='overwrite')
|
2016-05-11 14:24:16 -04:00
|
|
|
df1 = self.spark.read.parquet(output_dir)
|
2015-07-30 01:30:49 -04:00
|
|
|
point = df1.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
|
2015-07-30 01:30:49 -04:00
|
|
|
|
2016-03-25 01:34:55 -04:00
|
|
|
def test_union_with_udt(self):
|
2016-02-21 19:58:17 -05:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
|
|
|
row1 = (1.0, ExamplePoint(1.0, 2.0))
|
|
|
|
row2 = (2.0, ExamplePoint(3.0, 4.0))
|
|
|
|
schema = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", ExamplePointUDT(), False)])
|
2016-05-11 14:24:16 -04:00
|
|
|
df1 = self.spark.createDataFrame([row1], schema)
|
|
|
|
df2 = self.spark.createDataFrame([row2], schema)
|
2016-02-21 19:58:17 -05:00
|
|
|
|
2016-03-25 01:34:55 -04:00
|
|
|
result = df1.union(df2).orderBy("label").collect()
|
2016-02-21 19:58:17 -05:00
|
|
|
self.assertEqual(
|
|
|
|
result,
|
|
|
|
[
|
|
|
|
Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
|
|
|
|
Row(label=2.0, point=ExamplePoint(3.0, 4.0))
|
|
|
|
]
|
|
|
|
)
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_column_operators(self):
|
|
|
|
ci = self.df.key
|
|
|
|
cs = self.df.value
|
|
|
|
c = ci == cs
|
|
|
|
self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
|
2015-09-11 18:19:04 -04:00
|
|
|
rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in rcc))
|
2015-06-23 18:51:16 -04:00
|
|
|
cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in cb))
|
|
|
|
cbool = (ci & ci), (ci | ci), (~ci)
|
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in cbool))
|
2017-02-23 16:22:39 -05:00
|
|
|
css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\
|
|
|
|
cs.startswith('a'), cs.endswith('a')
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in css))
|
|
|
|
self.assertTrue(isinstance(ci.cast(LongType()), Column))
|
[SPARK-19701][SQL][PYTHON] Throws a correct exception for 'in' operator against column
## What changes were proposed in this pull request?
This PR proposes to remove the incorrect implementation of the `in` operator, which has not actually been executed so far (at least since Spark 1.5.2), and to throw a correct exception rather than saying it is a bool. I tested the code in 1.5.2, 1.6.3, 2.1.0 and in the master branch as below:
**1.5.2**
```python
>>> df = sqlContext.createDataFrame([[1]])
>>> 1 in df._1
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark-1.5.2-bin-hadoop2.6/python/pyspark/sql/column.py", line 418, in __nonzero__
raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
```
**1.6.3**
```python
>>> 1 in sqlContext.range(1).id
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark-1.6.3-bin-hadoop2.6/python/pyspark/sql/column.py", line 447, in __nonzero__
raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
```
**2.1.0**
```python
>>> 1 in spark.range(1).id
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark-2.1.0-bin-hadoop2.7/python/pyspark/sql/column.py", line 426, in __nonzero__
raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
```
**Current Master**
```python
>>> 1 in spark.range(1).id
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark/python/pyspark/sql/column.py", line 452, in __nonzero__
raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
```
**After**
```python
>>> 1 in spark.range(1).id
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark/python/pyspark/sql/column.py", line 184, in __contains__
raise ValueError("Cannot apply 'in' operator against a column: please use 'contains' "
ValueError: Cannot apply 'in' operator against a column: please use 'contains' in a string column or 'array_contains' function for an array column.
```
In more detail,
It seems the implementation intended to support this
```python
1 in df.column
```
However, currently, it throws an exception as below:
```python
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark/python/pyspark/sql/column.py", line 426, in __nonzero__
raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
ValueError: Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
```
What happens here is as below:
```python
class Column(object):
def __contains__(self, item):
print "I am contains"
return Column()
def __nonzero__(self):
raise Exception("I am nonzero.")
>>> 1 in Column()
I am contains
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 6, in __nonzero__
Exception: I am nonzero.
```
It seems `__contains__` is called first, and then `__nonzero__` or `__bool__` is called against the returned `Column()` to turn it into a bool (or an int, to be specific).
It seems `__nonzero__` (for Python 2), `__bool__` (for Python 3) and `__contains__` force the return value into a bool, unlike other operators. There are a few references about this below:
https://bugs.python.org/issue16011
http://stackoverflow.com/questions/12244074/python-source-code-for-built-in-in-operator/12244378#12244378
http://stackoverflow.com/questions/38542543/functionality-of-python-in-vs-contains/38542777
It seems we can't override `__nonzero__` or `__bool__` as a workaround to make this work, because they force the return type to be a bool, as below:
```python
class Column(object):
def __contains__(self, item):
print "I am contains"
return Column()
def __nonzero__(self):
return "a"
>>> 1 in Column()
I am contains
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: __nonzero__ should return bool or int, returned str
```
## How was this patch tested?
Added unit tests in `tests.py`.
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #17160 from HyukjinKwon/SPARK-19701.
2017-03-05 21:04:52 -05:00
|
|
|
self.assertRaisesRegexp(ValueError,
|
|
|
|
"Cannot apply 'in' operator against a column",
|
|
|
|
lambda: 1 in cs)
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2017-02-13 18:23:56 -05:00
|
|
|
def test_column_getitem(self):
|
|
|
|
from pyspark.sql.functions import col
|
|
|
|
|
|
|
|
self.assertIsInstance(col("foo")[1:3], Column)
|
|
|
|
self.assertIsInstance(col("foo")[0], Column)
|
|
|
|
self.assertIsInstance(col("foo")["bar"], Column)
|
|
|
|
self.assertRaises(ValueError, lambda: col("foo")[0:10:2])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_column_select(self):
|
|
|
|
df = self.df
|
|
|
|
self.assertEqual(self.testData, df.select("*").collect())
|
|
|
|
self.assertEqual(self.testData, df.select(df.key, df.value).collect())
|
|
|
|
self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
|
|
|
|
|
2015-05-02 02:43:24 -04:00
|
|
|
def test_freqItems(self):
|
|
|
|
vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
|
|
|
|
df = self.sc.parallelize(vals).toDF()
|
|
|
|
items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
|
|
|
|
self.assertTrue(1 in items[0])
|
|
|
|
self.assertTrue(-2.0 in items[1])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_aggregator(self):
|
|
|
|
df = self.df
|
|
|
|
g = df.groupBy()
|
|
|
|
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
|
|
|
|
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
|
|
|
|
|
2015-02-14 02:03:22 -05:00
|
|
|
from pyspark.sql import functions
|
|
|
|
self.assertEqual((0, u'99'),
|
|
|
|
tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
|
|
|
|
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
|
|
|
|
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2016-01-31 16:56:13 -05:00
|
|
|
def test_first_last_ignorenulls(self):
|
|
|
|
from pyspark.sql import functions
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.range(0, 100)
|
2016-01-31 16:56:13 -05:00
|
|
|
df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
|
|
|
|
df3 = df2.select(functions.first(df2.id, False).alias('a'),
|
|
|
|
functions.first(df2.id, True).alias('b'),
|
|
|
|
functions.last(df2.id, False).alias('c'),
|
|
|
|
functions.last(df2.id, True).alias('d'))
|
|
|
|
self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
|
|
|
|
|
2016-02-25 02:15:36 -05:00
|
|
|
def test_approxQuantile(self):
|
2017-02-01 17:11:28 -05:00
|
|
|
df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
|
2016-02-25 02:15:36 -05:00
|
|
|
aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
|
|
|
|
self.assertTrue(isinstance(aq, list))
|
|
|
|
self.assertEqual(len(aq), 3)
|
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aq))
|
2017-02-01 17:11:28 -05:00
|
|
|
aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1)
|
|
|
|
self.assertTrue(isinstance(aqs, list))
|
|
|
|
self.assertEqual(len(aqs), 2)
|
|
|
|
self.assertTrue(isinstance(aqs[0], list))
|
|
|
|
self.assertEqual(len(aqs[0]), 3)
|
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
|
|
|
|
self.assertTrue(isinstance(aqs[1], list))
|
|
|
|
self.assertEqual(len(aqs[1]), 3)
|
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
|
|
|
|
aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1)
|
|
|
|
self.assertTrue(isinstance(aqt, list))
|
|
|
|
self.assertEqual(len(aqt), 2)
|
|
|
|
self.assertTrue(isinstance(aqt[0], list))
|
|
|
|
self.assertEqual(len(aqt[0]), 3)
|
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
|
|
|
|
self.assertTrue(isinstance(aqt[1], list))
|
|
|
|
self.assertEqual(len(aqt[1]), 3)
|
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
|
|
|
|
self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
|
|
|
|
self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
|
|
|
|
self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
|
2016-02-25 02:15:36 -05:00
|
|
|
|
2015-05-04 00:44:39 -04:00
|
|
|
def test_corr(self):
|
|
|
|
import math
|
|
|
|
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
|
|
|
|
corr = df.stat.corr("a", "b")
|
|
|
|
self.assertTrue(abs(corr - 0.95734012) < 1e-6)
|
|
|
|
|
2015-05-01 16:29:17 -04:00
|
|
|
def test_cov(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
|
|
|
|
cov = df.stat.cov("a", "b")
|
|
|
|
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
|
|
|
|
|
2015-05-04 20:02:49 -04:00
|
|
|
def test_crosstab(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
|
|
|
|
ct = df.stat.crosstab("a", "b").collect()
|
|
|
|
ct = sorted(ct, key=lambda x: x[0])
|
|
|
|
for i, row in enumerate(ct):
|
|
|
|
self.assertEqual(row[0], str(i))
|
|
|
|
self.assertTrue(row[1], 1)
|
|
|
|
self.assertTrue(row[2], 1)
|
|
|
|
|
2015-04-29 03:09:24 -04:00
|
|
|
def test_math_functions(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
|
2015-05-06 01:56:01 -04:00
|
|
|
from pyspark.sql import functions
|
2015-04-29 03:09:24 -04:00
|
|
|
import math
|
|
|
|
|
|
|
|
def get_values(l):
|
|
|
|
return [j[0] for j in l]
|
|
|
|
|
|
|
|
def assert_close(a, b):
|
|
|
|
c = get_values(b)
|
|
|
|
diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
|
|
|
|
return sum(diff) == len(a)
|
|
|
|
assert_close([math.cos(i) for i in range(10)],
|
|
|
|
df.select(functions.cos(df.a)).collect())
|
|
|
|
assert_close([math.cos(i) for i in range(10)],
|
|
|
|
df.select(functions.cos("a")).collect())
|
|
|
|
assert_close([math.sin(i) for i in range(10)],
|
|
|
|
df.select(functions.sin(df.a)).collect())
|
|
|
|
assert_close([math.sin(i) for i in range(10)],
|
|
|
|
df.select(functions.sin(df['a'])).collect())
|
|
|
|
assert_close([math.pow(i, 2 * i) for i in range(10)],
|
|
|
|
df.select(functions.pow(df.a, df.b)).collect())
|
|
|
|
assert_close([math.pow(i, 2) for i in range(10)],
|
|
|
|
df.select(functions.pow(df.a, 2)).collect())
|
|
|
|
assert_close([math.pow(i, 2) for i in range(10)],
|
|
|
|
df.select(functions.pow(df.a, 2.0)).collect())
|
|
|
|
assert_close([math.hypot(i, 2 * i) for i in range(10)],
|
|
|
|
df.select(functions.hypot(df.a, df.b)).collect())
|
|
|
|
|
2015-05-01 00:56:03 -04:00
|
|
|
def test_rand_functions(self):
|
|
|
|
df = self.df
|
|
|
|
from pyspark.sql import functions
|
|
|
|
rnd = df.select('key', functions.rand()).collect()
|
|
|
|
for row in rnd:
|
|
|
|
assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
|
|
|
|
rndn = df.select('key', functions.randn(5)).collect()
|
|
|
|
for row in rndn:
|
|
|
|
assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
|
|
|
|
|
2015-08-06 20:03:14 -04:00
|
|
|
# If the specified seed is 0, we should use it.
|
|
|
|
# https://issues.apache.org/jira/browse/SPARK-9691
|
|
|
|
rnd1 = df.select('key', functions.rand(0)).collect()
|
|
|
|
rnd2 = df.select('key', functions.rand(0)).collect()
|
|
|
|
self.assertEqual(sorted(rnd1), sorted(rnd2))
|
|
|
|
|
|
|
|
rndn1 = df.select('key', functions.randn(0)).collect()
|
|
|
|
rndn2 = df.select('key', functions.randn(0)).collect()
|
|
|
|
self.assertEqual(sorted(rndn1), sorted(rndn2))
|
|
|
|
|
2017-03-26 21:40:00 -04:00
|
|
|
def test_array_contains_function(self):
|
|
|
|
from pyspark.sql.functions import array_contains
|
|
|
|
|
|
|
|
df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data'])
|
|
|
|
actual = df.select(array_contains(df.data, 1).alias('b')).collect()
|
|
|
|
# The value argument can be implicitly cast to the array's element type.
|
|
|
|
self.assertEqual([Row(b=True), Row(b=False)], actual)
|
|
|
|
|
2015-05-05 16:23:53 -04:00
|
|
|
def test_between_function(self):
|
|
|
|
df = self.sc.parallelize([
|
|
|
|
Row(a=1, b=2, c=3),
|
|
|
|
Row(a=2, b=1, c=3),
|
|
|
|
Row(a=4, b=1, c=4)]).toDF()
|
|
|
|
self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
|
|
|
|
df.filter(df.a.between(df.b, df.c)).collect())
|
|
|
|
|
2015-06-29 17:15:15 -04:00
|
|
|
def test_struct_type(self):
|
|
|
|
from pyspark.sql.types import StructType, StringType, StructField
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True),
|
|
|
|
StructField("f2", StringType(), True, None)])
|
|
|
|
self.assertEqual(struct1, struct2)
|
|
|
|
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True)])
|
|
|
|
self.assertNotEqual(struct1, struct2)
|
|
|
|
|
|
|
|
struct1 = (StructType().add(StructField("f1", StringType(), True))
|
|
|
|
.add(StructField("f2", StringType(), True, None)))
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True),
|
|
|
|
StructField("f2", StringType(), True, None)])
|
|
|
|
self.assertEqual(struct1, struct2)
|
|
|
|
|
|
|
|
struct1 = (StructType().add(StructField("f1", StringType(), True))
|
|
|
|
.add(StructField("f2", StringType(), True, None)))
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True)])
|
|
|
|
self.assertNotEqual(struct1, struct2)
|
|
|
|
|
|
|
|
# Catch exception raised during improper construction
|
2016-04-20 16:45:14 -04:00
|
|
|
with self.assertRaises(ValueError):
|
2015-06-29 17:15:15 -04:00
|
|
|
struct1 = StructType().add("name")
|
2016-04-20 16:45:14 -04:00
|
|
|
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
for field in struct1:
|
|
|
|
self.assertIsInstance(field, StructField)
|
|
|
|
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
self.assertEqual(len(struct1), 2)
|
|
|
|
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
self.assertIs(struct1["f1"], struct1.fields[0])
|
|
|
|
self.assertIs(struct1[0], struct1.fields[0])
|
|
|
|
self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
|
|
|
|
with self.assertRaises(KeyError):
|
|
|
|
not_a_field = struct1["f9"]
|
|
|
|
with self.assertRaises(IndexError):
|
|
|
|
not_a_field = struct1[9]
|
|
|
|
with self.assertRaises(TypeError):
|
|
|
|
not_a_field = struct1[9.9]
|
2015-06-29 17:15:15 -04:00
|
|
|
|
2016-01-27 12:55:10 -05:00
|
|
|
def test_metadata_null(self):
|
|
|
|
from pyspark.sql.types import StructType, StringType, StructField
|
|
|
|
schema = StructType([StructField("f1", StringType(), True, None),
|
|
|
|
StructField("f2", StringType(), True, {'a': None})])
|
|
|
|
rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.createDataFrame(rdd, schema)
|
2016-01-27 12:55:10 -05:00
|
|
|
|
2015-02-10 20:29:52 -05:00
|
|
|
def test_save_and_load(self):
|
|
|
|
df = self.df
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.json(tmpPath)
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
|
|
|
schema = StructType([StructField("value", StringType(), True)])
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath, schema)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.json(tmpPath, "overwrite")
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.save(format="json", mode="overwrite", path=tmpPath,
|
|
|
|
noUse="this options will not be used in save.")
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.load(format="json", path=tmpPath,
|
|
|
|
noUse="this options will not be used in load.")
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
2016-05-11 14:24:16 -04:00
|
|
|
defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
|
2015-02-10 20:29:52 -05:00
|
|
|
"org.apache.spark.sql.parquet")
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
|
|
|
|
actual = self.spark.read.load(path=tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
|
2015-06-22 16:51:23 -04:00
|
|
|
|
2016-04-22 12:19:36 -04:00
|
|
|
csvpath = os.path.join(tempfile.mkdtemp(), 'data')
|
|
|
|
df.write.option('quote', None).format('csv').save(csvpath)
|
|
|
|
|
2015-06-22 16:51:23 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
|
|
|
def test_save_and_load_builder(self):
|
|
|
|
df = self.df
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
df.write.json(tmpPath)
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath)
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
|
|
|
|
|
|
|
schema = StructType([StructField("value", StringType(), True)])
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath, schema)
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
|
|
|
|
|
|
|
|
df.write.mode("overwrite").json(tmpPath)
|
2016-05-11 14:24:16 -04:00
|
|
|
actual = self.spark.read.json(tmpPath)
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
|
|
|
|
|
|
|
df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
|
2015-06-29 03:13:39 -04:00
|
|
|
.option("noUse", "this option will not be used in save.")\
|
2015-06-22 16:51:23 -04:00
|
|
|
.format("json").save(path=tmpPath)
|
|
|
|
actual =\
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.read.format("json")\
|
|
|
|
.load(path=tmpPath, noUse="this options will not be used in load.")
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
|
|
|
|
2016-05-11 14:24:16 -04:00
|
|
|
defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
|
2015-06-22 16:51:23 -04:00
|
|
|
"org.apache.spark.sql.parquet")
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
|
|
|
|
actual = self.spark.read.load(path=tmpPath)
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
|
2015-02-10 20:29:52 -05:00
|
|
|
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
[SPARK-19876][SS][WIP] OneTime Trigger Executor
## What changes were proposed in this pull request?
An additional trigger and trigger executor that will execute a single trigger only. One can use this OneTime trigger to have more control over the scheduling of triggers.
In addition, this patch requires an optimization to StreamExecution that logs a commit record at the end of successfully processing a batch. This new commit log is used to determine the next batch (offsets) to process after a restart, instead of using the offset log itself for that purpose; relying on the offset log would always re-process the previously logged batch, which would not permit the OneTime trigger feature.
## How was this patch tested?
A number of existing tests have been revised. These tests all assumed that when restarting a stream, the last batch in the offset log is to be re-processed. Given that we now have a commit log that will tell us if that last batch was processed successfully, the results/assumptions of those tests needed to be revised accordingly.
In addition, a OneTime trigger test was added to StreamingQuerySuite, which tests:
- The semantics of OneTime trigger (i.e., on start, execute a single batch, then stop).
- The case when the commit log was not able to successfully log the completion of a batch before restart, which would mean that we should fall back to what's in the offset log.
- A OneTime trigger execution that results in an exception being thrown.
marmbrus tdas zsxwing
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Tyson Condie <tcondie@gmail.com>
Author: Tathagata Das <tathagata.das1565@gmail.com>
Closes #17219 from tcondie/stream-commit.
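As a hedged usage sketch of the one-time trigger described above (not taken from the patch itself; the memory sink, query name, and input path are illustrative assumptions, and a `SparkSession` named `spark` is assumed):
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.readStream.format("text").load("python/test_support/sql/streaming")

# trigger(once=True): process whatever data is currently available as a single
# batch, then stop, instead of running continuously on a processing-time schedule.
query = (df.writeStream
           .format("memory")
           .queryName("once_query")
           .trigger(once=True)
           .start())
query.awaitTermination()
```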
2017-03-23 17:32:05 -04:00
|
|
|
def test_stream_trigger(self):
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
2017-03-23 17:32:05 -04:00
|
|
|
|
|
|
|
# Should take at least one arg
|
|
|
|
try:
|
|
|
|
df.writeStream.trigger()
|
|
|
|
except ValueError:
|
|
|
|
pass
|
|
|
|
|
|
|
|
# Should not take multiple args
|
|
|
|
try:
|
|
|
|
df.writeStream.trigger(once=True, processingTime='5 seconds')
|
|
|
|
except ValueError:
|
|
|
|
pass
|
|
|
|
|
|
|
|
# Should take only keyword args
|
2016-04-20 13:32:01 -04:00
|
|
|
try:
|
2016-06-14 20:58:45 -04:00
|
|
|
df.writeStream.trigger('5 seconds')
|
2016-04-20 13:32:01 -04:00
|
|
|
self.fail("Should have thrown an exception")
|
|
|
|
except TypeError:
|
|
|
|
pass
|
|
|
|
|
|
|
|
def test_stream_read_options(self):
|
|
|
|
schema = StructType([StructField("data", StringType(), False)])
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream\
|
|
|
|
.format('text')\
|
|
|
|
.option('path', 'python/test_support/sql/streaming')\
|
|
|
|
.schema(schema)\
|
|
|
|
.load()
|
2016-04-20 13:32:01 -04:00
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<data:string>")
|
|
|
|
|
|
|
|
def test_stream_read_options_overwrite(self):
|
|
|
|
bad_schema = StructType([StructField("test", IntegerType(), False)])
|
|
|
|
schema = StructType([StructField("data", StringType(), False)])
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \
|
|
|
|
.schema(bad_schema)\
|
|
|
|
.load(path='python/test_support/sql/streaming', schema=schema, format='text')
|
2016-04-20 13:32:01 -04:00
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<data:string>")
|
|
|
|
|
|
|
|
def test_stream_save_options(self):
|
2016-12-15 17:26:54 -05:00
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \
|
|
|
|
.withColumn('id', lit(1))
|
2016-06-15 13:46:02 -04:00
|
|
|
for q in self.spark._wrapped.streams.active:
|
|
|
|
q.stop()
|
2016-04-20 13:32:01 -04:00
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
out = os.path.join(tmpPath, 'out')
|
|
|
|
chk = os.path.join(tmpPath, 'chk')
|
2016-06-15 13:46:02 -04:00
|
|
|
q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \
|
2016-12-15 17:26:54 -05:00
|
|
|
.format('parquet').partitionBy('id').outputMode('append').option('path', out).start()
|
2016-04-28 18:22:28 -04:00
|
|
|
try:
|
2016-06-15 13:46:02 -04:00
|
|
|
self.assertEqual(q.name, 'this_query')
|
|
|
|
self.assertTrue(q.isActive)
|
|
|
|
q.processAllAvailable()
|
2016-04-28 18:22:28 -04:00
|
|
|
output_files = []
|
|
|
|
for _, _, files in os.walk(out):
|
2016-05-03 13:58:26 -04:00
|
|
|
output_files.extend([f for f in files if not f.startswith('.')])
|
2016-04-28 18:22:28 -04:00
|
|
|
self.assertTrue(len(output_files) > 0)
|
|
|
|
self.assertTrue(len(os.listdir(chk)) > 0)
|
|
|
|
finally:
|
2016-06-15 13:46:02 -04:00
|
|
|
q.stop()
|
2016-04-28 18:22:28 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
2016-04-20 13:32:01 -04:00
|
|
|
|
|
|
|
def test_stream_save_options_overwrite(self):
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
2016-06-15 13:46:02 -04:00
|
|
|
for q in self.spark._wrapped.streams.active:
|
|
|
|
q.stop()
|
2016-04-20 13:32:01 -04:00
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
out = os.path.join(tmpPath, 'out')
|
|
|
|
chk = os.path.join(tmpPath, 'chk')
|
|
|
|
fake1 = os.path.join(tmpPath, 'fake1')
|
|
|
|
fake2 = os.path.join(tmpPath, 'fake2')
|
2016-06-15 13:46:02 -04:00
|
|
|
q = df.writeStream.option('checkpointLocation', fake1)\
|
2016-06-14 20:58:45 -04:00
|
|
|
.format('memory').option('path', fake2) \
|
2016-05-31 18:57:01 -04:00
|
|
|
.queryName('fake_query').outputMode('append') \
|
2016-06-14 20:58:45 -04:00
|
|
|
.start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
|
2016-05-31 18:57:01 -04:00
|
|
|
|
2016-04-28 18:22:28 -04:00
|
|
|
try:
|
2016-06-15 13:46:02 -04:00
|
|
|
self.assertEqual(q.name, 'this_query')
|
|
|
|
self.assertTrue(q.isActive)
|
|
|
|
q.processAllAvailable()
|
2016-04-28 18:22:28 -04:00
|
|
|
output_files = []
|
|
|
|
for _, _, files in os.walk(out):
|
2016-05-03 13:58:26 -04:00
|
|
|
output_files.extend([f for f in files if not f.startswith('.')])
|
2016-04-28 18:22:28 -04:00
|
|
|
self.assertTrue(len(output_files) > 0)
|
|
|
|
self.assertTrue(len(os.listdir(chk)) > 0)
|
|
|
|
self.assertFalse(os.path.isdir(fake1)) # should not have been created
|
|
|
|
self.assertFalse(os.path.isdir(fake2)) # should not have been created
|
|
|
|
finally:
|
2016-06-15 13:46:02 -04:00
|
|
|
q.stop()
|
2016-04-28 18:22:28 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
2016-04-20 13:32:01 -04:00
|
|
|
|
[SPARK-18516][SQL] Split state and progress in streaming
This PR separates the status of a `StreamingQuery` into two separate APIs:
- `status` - describes the status of a `StreamingQuery` at this moment, including what phase of processing is currently happening and if data is available.
- `recentProgress` - an array of statistics about the most recent microbatches that have executed.
A recent progress contains the following information:
```
{
"id" : "2be8670a-fce1-4859-a530-748f29553bb6",
"name" : "query-29",
"timestamp" : 1479705392724,
"inputRowsPerSecond" : 230.76923076923077,
"processedRowsPerSecond" : 10.869565217391303,
"durationMs" : {
"triggerExecution" : 276,
"queryPlanning" : 3,
"getBatch" : 5,
"getOffset" : 3,
"addBatch" : 234,
"walCommit" : 30
},
"currentWatermark" : 0,
"stateOperators" : [ ],
"sources" : [ {
"description" : "KafkaSource[Subscribe[topic-14]]",
"startOffset" : {
"topic-14" : {
"2" : 0,
"4" : 1,
"1" : 0,
"3" : 0,
"0" : 0
}
},
"endOffset" : {
"topic-14" : {
"2" : 1,
"4" : 2,
"1" : 0,
"3" : 0,
"0" : 1
}
},
"numRecords" : 3,
"inputRowsPerSecond" : 230.76923076923077,
"processedRowsPerSecond" : 10.869565217391303
} ]
}
```
Additionally, in order to make it possible to correlate progress updates across restarts, we change the `id` field from an integer that is unique within the JVM to a `UUID` that is globally unique.
Author: Tathagata Das <tathagata.das1565@gmail.com>
Author: Michael Armbrust <michael@databricks.com>
Closes #15954 from marmbrus/queryProgress.
2016-11-29 20:24:17 -05:00
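# A minimal, hypothetical sketch of the status/progress API described above,
# assuming an active SparkSession named `spark`; the helper name below is
# illustrative only.
def _progress_sketch(spark):
    sdf = spark.readStream.format('text').load('python/test_support/sql/streaming')
    q = sdf.writeStream.format('memory').queryName('progress_sketch').start()
    try:
        q.processAllAvailable()
        print(q.status)               # dict with 'message', 'isDataAvailable', 'isTriggerActive'
        last = q.lastProgress         # dict for the most recent micro-batch, or None
        if last is not None:
            print(last['id'], last['name'])
        print(len(q.recentProgress))  # list of the most recent progress dicts
    finally:
        q.stop()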
|
|
|
def test_stream_status_and_progress(self):
|
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
|
|
|
for q in self.spark._wrapped.streams.active:
|
|
|
|
q.stop()
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
out = os.path.join(tmpPath, 'out')
|
|
|
|
chk = os.path.join(tmpPath, 'chk')
|
2016-12-14 16:36:41 -05:00
|
|
|
|
|
|
|
def func(x):
|
|
|
|
time.sleep(1)
|
|
|
|
return x
|
|
|
|
|
|
|
|
from pyspark.sql.functions import col, udf
|
|
|
|
sleep_udf = udf(func)
|
|
|
|
|
|
|
|
# Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there
|
|
|
|
# were no updates.
|
|
|
|
q = df.select(sleep_udf(col("value")).alias('value')).writeStream \
|
2016-11-29 20:24:17 -05:00
|
|
|
.start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
|
|
|
|
try:
|
2016-12-14 16:36:41 -05:00
|
|
|
# "lastProgress" will return None in most cases. However, as it may be flaky when
|
|
|
|
# Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress"
|
|
|
|
# is likely to throw an error and make this test flaky, so we should still be
|
|
|
|
# able to detect broken code.
|
|
|
|
q.lastProgress
|
|
|
|
|
2016-11-29 20:24:17 -05:00
|
|
|
q.processAllAvailable()
|
|
|
|
lastProgress = q.lastProgress
|
2016-12-07 18:36:29 -05:00
|
|
|
recentProgress = q.recentProgress
|
2016-11-30 02:08:56 -05:00
|
|
|
status = q.status
|
2016-11-29 20:24:17 -05:00
|
|
|
self.assertEqual(lastProgress['name'], q.name)
|
|
|
|
self.assertEqual(lastProgress['id'], q.id)
|
2016-12-07 18:36:29 -05:00
|
|
|
self.assertTrue(any(p == lastProgress for p in recentProgress))
|
2016-11-30 02:08:56 -05:00
|
|
|
self.assertTrue(
|
|
|
|
"message" in status and
|
|
|
|
"isDataAvailable" in status and
|
|
|
|
"isTriggerActive" in status)
|
2016-11-29 20:24:17 -05:00
|
|
|
finally:
|
|
|
|
q.stop()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
2016-04-20 13:32:01 -04:00
|
|
|
def test_stream_await_termination(self):
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
2016-06-15 13:46:02 -04:00
|
|
|
for q in self.spark._wrapped.streams.active:
|
|
|
|
q.stop()
|
2016-04-20 13:32:01 -04:00
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
out = os.path.join(tmpPath, 'out')
|
|
|
|
chk = os.path.join(tmpPath, 'chk')
|
2016-06-15 13:46:02 -04:00
|
|
|
q = df.writeStream\
|
2016-06-14 20:58:45 -04:00
|
|
|
.start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
|
2016-04-20 13:32:01 -04:00
|
|
|
try:
|
2016-06-15 13:46:02 -04:00
|
|
|
self.assertTrue(q.isActive)
|
2016-04-28 18:22:28 -04:00
|
|
|
try:
|
2016-06-15 13:46:02 -04:00
|
|
|
q.awaitTermination("hello")
|
2016-04-28 18:22:28 -04:00
|
|
|
self.fail("Expected a ValueError")
|
|
|
|
except ValueError:
|
|
|
|
pass
|
|
|
|
now = time.time()
|
|
|
|
# test should take at least 2 seconds
|
2016-06-15 13:46:02 -04:00
|
|
|
res = q.awaitTermination(2.6)
|
2016-04-28 18:22:28 -04:00
|
|
|
duration = time.time() - now
|
|
|
|
self.assertTrue(duration >= 2)
|
|
|
|
self.assertFalse(res)
|
|
|
|
finally:
|
2016-06-15 13:46:02 -04:00
|
|
|
q.stop()
|
2016-04-28 18:22:28 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
2016-12-05 14:36:11 -05:00
|
|
|
def test_stream_exception(self):
|
|
|
|
sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
|
|
|
sq = sdf.writeStream.format('memory').queryName('query_explain').start()
|
|
|
|
try:
|
|
|
|
sq.processAllAvailable()
|
|
|
|
self.assertEqual(sq.exception(), None)
|
|
|
|
finally:
|
|
|
|
sq.stop()
|
|
|
|
|
|
|
|
from pyspark.sql.functions import col, udf
|
|
|
|
from pyspark.sql.utils import StreamingQueryException
|
|
|
|
bad_udf = udf(lambda x: 1 / 0)
|
|
|
|
sq = sdf.select(bad_udf(col("value")))\
|
|
|
|
.writeStream\
|
|
|
|
.format('memory')\
|
|
|
|
.queryName('this_query')\
|
|
|
|
.start()
|
|
|
|
try:
|
|
|
|
# Process some data to fail the query
|
|
|
|
sq.processAllAvailable()
|
|
|
|
self.fail("bad udf should fail the query")
|
|
|
|
except StreamingQueryException as e:
|
|
|
|
# This is expected
|
|
|
|
self.assertTrue("ZeroDivisionError" in e.desc)
|
|
|
|
finally:
|
|
|
|
sq.stop()
|
|
|
|
self.assertTrue(type(sq.exception()) is StreamingQueryException)
|
|
|
|
self.assertTrue("ZeroDivisionError" in sq.exception().desc)
|
|
|
|
|
2016-04-28 18:22:28 -04:00
|
|
|
def test_query_manager_await_termination(self):
|
2016-06-14 20:58:45 -04:00
|
|
|
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
|
2016-06-15 13:46:02 -04:00
|
|
|
for q in self.spark._wrapped.streams.active:
|
|
|
|
q.stop()
|
2016-04-28 18:22:28 -04:00
|
|
|
tmpPath = tempfile.mkdtemp()
|
2016-04-20 13:32:01 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
2016-04-28 18:22:28 -04:00
|
|
|
self.assertTrue(df.isStreaming)
|
|
|
|
out = os.path.join(tmpPath, 'out')
|
|
|
|
chk = os.path.join(tmpPath, 'chk')
|
2016-06-15 13:46:02 -04:00
|
|
|
q = df.writeStream\
|
2016-06-14 20:58:45 -04:00
|
|
|
.start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
|
2016-04-28 18:22:28 -04:00
|
|
|
try:
|
2016-06-15 13:46:02 -04:00
|
|
|
self.assertTrue(q.isActive)
|
2016-04-28 18:22:28 -04:00
|
|
|
try:
|
2016-05-11 14:24:16 -04:00
|
|
|
self.spark._wrapped.streams.awaitAnyTermination("hello")
|
2016-04-28 18:22:28 -04:00
|
|
|
self.fail("Expected a ValueError")
|
|
|
|
except ValueError:
|
|
|
|
pass
|
|
|
|
now = time.time()
|
|
|
|
# test should take at least 2 seconds
|
2016-05-11 14:24:16 -04:00
|
|
|
res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
|
2016-04-28 18:22:28 -04:00
|
|
|
duration = time.time() - now
|
|
|
|
self.assertTrue(duration >= 2)
|
|
|
|
self.assertFalse(res)
|
|
|
|
finally:
|
2016-06-15 13:46:02 -04:00
|
|
|
q.stop()
|
2016-04-28 18:22:28 -04:00
|
|
|
shutil.rmtree(tmpPath)
|
2016-04-20 13:32:01 -04:00
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_help_command(self):
|
|
|
|
# Regression test for SPARK-5464
|
|
|
|
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.read.json(rdd)
|
2015-02-03 19:01:56 -05:00
|
|
|
# render_doc() reproduces the help() exception without printing output
|
|
|
|
pydoc.render_doc(df)
|
|
|
|
pydoc.render_doc(df.foo)
|
|
|
|
pydoc.render_doc(df.take(1))
|
|
|
|
|
2015-04-16 20:33:57 -04:00
|
|
|
def test_access_column(self):
|
|
|
|
df = self.df
|
|
|
|
self.assertTrue(isinstance(df.key, Column))
|
|
|
|
self.assertTrue(isinstance(df['key'], Column))
|
|
|
|
self.assertTrue(isinstance(df[0], Column))
|
|
|
|
self.assertRaises(IndexError, lambda: df[2])
|
2015-08-14 17:09:46 -04:00
|
|
|
self.assertRaises(AnalysisException, lambda: df["bad_key"])
|
2015-04-16 20:33:57 -04:00
|
|
|
self.assertRaises(TypeError, lambda: df[{}])
|
|
|
|
|
2015-07-01 19:43:18 -04:00
|
|
|
def test_column_name_with_non_ascii(self):
|
2016-05-18 14:18:33 -04:00
|
|
|
if sys.version >= '3':
|
|
|
|
columnName = "数量"
|
|
|
|
self.assertTrue(isinstance(columnName, str))
|
|
|
|
else:
|
|
|
|
columnName = unicode("数量", "utf-8")
|
|
|
|
self.assertTrue(isinstance(columnName, unicode))
|
|
|
|
schema = StructType([StructField(columnName, LongType(), True)])
|
|
|
|
df = self.spark.createDataFrame([(1,)], schema)
|
|
|
|
self.assertEqual(schema, df.schema)
|
2015-07-01 19:43:18 -04:00
|
|
|
self.assertEqual("DataFrame[数量: bigint]", str(df))
|
|
|
|
self.assertEqual([("数量", 'bigint')], df.dtypes)
|
|
|
|
self.assertEqual(1, df.select("数量").first()[0])
|
|
|
|
self.assertEqual(1, df.select(df["数量"]).first()[0])
|
|
|
|
|
2015-04-16 20:33:57 -04:00
|
|
|
def test_access_nested_types(self):
|
|
|
|
df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
|
|
|
|
self.assertEqual(1, df.select(df.l[0]).first()[0])
|
|
|
|
self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
|
|
|
|
self.assertEqual(1, df.select(df.r.a).first()[0])
|
|
|
|
self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
|
|
|
|
self.assertEqual("v", df.select(df.d["k"]).first()[0])
|
|
|
|
self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
|
|
|
|
|
2015-05-08 14:49:38 -04:00
|
|
|
def test_field_accessor(self):
|
|
|
|
df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
|
|
|
|
self.assertEqual(1, df.select(df.l[0]).first()[0])
|
|
|
|
self.assertEqual(1, df.select(df.r["a"]).first()[0])
|
2015-08-14 17:09:46 -04:00
|
|
|
self.assertEqual(1, df.select(df["r.a"]).first()[0])
|
2015-05-08 14:49:38 -04:00
|
|
|
self.assertEqual("b", df.select(df.r["b"]).first()[0])
|
2015-08-14 17:09:46 -04:00
|
|
|
self.assertEqual("b", df.select(df["r.b"]).first()[0])
|
2015-05-08 14:49:38 -04:00
|
|
|
self.assertEqual("v", df.select(df.d["k"]).first()[0])
|
|
|
|
|
2015-02-18 17:17:04 -05:00
|
|
|
def test_infer_long_type(self):
|
|
|
|
longrow = [Row(f1='a', f2=100000000000000)]
|
|
|
|
df = self.sc.parallelize(longrow).toDF()
|
|
|
|
self.assertEqual(df.schema.fields[1].dataType, LongType())
|
|
|
|
|
|
|
|
# saving this as Parquet caused issues as well.
|
|
|
|
output_dir = os.path.join(self.tempdir.name, "infer_long_type")
|
2016-01-04 21:02:38 -05:00
|
|
|
df.write.parquet(output_dir)
|
2016-05-11 14:24:16 -04:00
|
|
|
df1 = self.spark.read.parquet(output_dir)
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual('a', df1.first().f1)
|
|
|
|
self.assertEqual(100000000000000, df1.first().f2)
|
2015-02-18 17:17:04 -05:00
|
|
|
|
|
|
|
self.assertEqual(_infer_type(1), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**10), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**20), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**31 - 1), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**31), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**61), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**71), LongType())
|
|
|
|
|
2015-04-21 03:08:18 -04:00
|
|
|
def test_filter_with_datetime(self):
|
|
|
|
time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
|
|
|
|
date = time.date()
|
|
|
|
row = Row(date=date, time=time)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-04-21 03:08:18 -04:00
|
|
|
self.assertEqual(1, df.filter(df.date == date).count())
|
|
|
|
self.assertEqual(1, df.filter(df.time == time).count())
|
|
|
|
self.assertEqual(0, df.filter(df.date > date).count())
|
|
|
|
self.assertEqual(0, df.filter(df.time > time).count())
|
|
|
|
|
[SPARK-10162] [SQL] Fix the timezone omitting for PySpark Dataframe filter function
This PR addresses [SPARK-10162](https://issues.apache.org/jira/browse/SPARK-10162)
The issue is with DataFrame filter() function, if datetime.datetime is passed to it:
* Timezone information of this datetime is ignored
* This datetime is assumed to be in local timezone, which depends on the OS timezone setting
Fix includes both code change and regression test. Problem reproduction code on master:
```python
import pytz
from datetime import datetime
from pyspark.sql import *
from pyspark.sql.types import *
sqc = SQLContext(sc)
df = sqc.createDataFrame([], StructType([StructField("dt", TimestampType())]))
m1 = pytz.timezone('UTC')
m2 = pytz.timezone('Etc/GMT+3')
df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
```
It gives the same timestamp ignoring time zone:
```
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
Filter (dt#0 > 946713600000000)
Scan PhysicalRDD[dt#0]
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
Filter (dt#0 > 946713600000000)
Scan PhysicalRDD[dt#0]
```
After the fix:
```
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
Filter (dt#0 > 946684800000000)
Scan PhysicalRDD[dt#0]
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
Filter (dt#0 > 946695600000000)
Scan PhysicalRDD[dt#0]
```
PR [8536](https://github.com/apache/spark/pull/8536) was accidentally closed when I dropped the repo
Author: 0x0FFF <programmerag@gmail.com>
Closes #8555 from 0x0FFF/SPARK-10162.
2015-09-01 17:34:59 -04:00
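# A minimal, hypothetical sketch of the timezone-aware filtering fixed above,
# assuming Python 3 and an active SparkSession named `spark`; the test below
# exercises the same behaviour with this file's own UTCOffsetTimezone helper.
def _tz_filter_sketch(spark):
    from datetime import datetime, timedelta, timezone
    from pyspark.sql import Row
    stored = datetime(2000, 1, 1, tzinfo=timezone.utc)
    df = spark.createDataFrame([Row(dt=stored)])
    # The same instant expressed in UTC-3; with the fix this matches the stored row.
    same_instant = datetime(1999, 12, 31, 21, 0, tzinfo=timezone(timedelta(hours=-3)))
    assert df.filter(df.dt == same_instant).count() == 1
    # The same clock time at UTC-3 is a different instant, so it no longer matches.
    later_instant = datetime(2000, 1, 1, tzinfo=timezone(timedelta(hours=-3)))
    assert df.filter(df.dt == later_instant).count() == 0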
|
|
|
def test_filter_with_datetime_timezone(self):
|
|
|
|
dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
|
|
|
|
dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
|
|
|
|
row = Row(date=dt1)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-09-01 17:34:59 -04:00
|
|
|
self.assertEqual(0, df.filter(df.date == dt2).count())
|
|
|
|
self.assertEqual(1, df.filter(df.date > dt2).count())
|
|
|
|
self.assertEqual(0, df.filter(df.date < dt2).count())
|
|
|
|
|
2015-06-11 04:00:41 -04:00
|
|
|
def test_time_with_timezone(self):
|
|
|
|
day = datetime.date.today()
|
|
|
|
now = datetime.datetime.now()
|
2015-07-10 16:05:23 -04:00
|
|
|
ts = time.mktime(now.timetuple())
|
2015-06-11 04:00:41 -04:00
|
|
|
# class in __main__ is not serializable
|
2015-09-01 17:34:59 -04:00
|
|
|
from pyspark.sql.tests import UTCOffsetTimezone
|
|
|
|
utc = UTCOffsetTimezone()
|
2015-07-10 16:05:23 -04:00
|
|
|
utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds
|
2015-07-10 20:44:21 -04:00
|
|
|
# add microseconds to utcnow (keeping year,month,day,hour,minute,second)
|
2015-07-10 16:05:23 -04:00
|
|
|
utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([(day, now, utcnow)])
|
2015-06-11 04:00:41 -04:00
|
|
|
day1, now1, utcnow1 = df.first()
|
2015-07-09 17:43:38 -04:00
|
|
|
self.assertEqual(day1, day)
|
|
|
|
self.assertEqual(now, now1)
|
|
|
|
self.assertEqual(now, utcnow1)
|
2015-06-11 04:00:41 -04:00
|
|
|
|
2017-03-09 13:34:54 -05:00
|
|
|
# regression test for SPARK-19561
|
|
|
|
def test_datetime_at_epoch(self):
|
|
|
|
epoch = datetime.datetime.fromtimestamp(0)
|
|
|
|
df = self.spark.createDataFrame([Row(date=epoch)])
|
|
|
|
first = df.select('date', lit(epoch).alias('lit_date')).first()
|
|
|
|
self.assertEqual(first['date'], epoch)
|
|
|
|
self.assertEqual(first['lit_date'], epoch)
|
|
|
|
|
2015-07-08 21:22:53 -04:00
|
|
|
def test_decimal(self):
|
|
|
|
from decimal import Decimal
|
|
|
|
schema = StructType([StructField("decimal", DecimalType(10, 5))])
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema)
|
2015-07-08 21:22:53 -04:00
|
|
|
row = df.select(df.decimal + 1).first()
|
|
|
|
self.assertEqual(row[0], Decimal("4.14159"))
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
df.write.parquet(tmpPath)
|
2016-05-11 14:24:16 -04:00
|
|
|
df2 = self.spark.read.parquet(tmpPath)
|
2015-07-08 21:22:53 -04:00
|
|
|
row = df2.first()
|
|
|
|
self.assertEqual(row[0], Decimal("3.14159"))
|
|
|
|
|
2015-03-30 23:47:10 -04:00
|
|
|
def test_dropna(self):
|
|
|
|
schema = StructType([
|
|
|
|
StructField("name", StringType(), True),
|
|
|
|
StructField("age", IntegerType(), True),
|
|
|
|
StructField("height", DoubleType(), True)])
|
|
|
|
|
|
|
|
# shouldn't drop a non-null row
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', 50, 80.1)], schema).dropna().count(),
|
|
|
|
1)
|
|
|
|
|
|
|
|
# dropping rows with a single null value
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, 80.1)], schema).dropna().count(),
|
|
|
|
0)
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# if how = 'all', only drop rows if all values are null
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
|
|
|
|
1)
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(None, None, None)], schema).dropna(how='all').count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# how and subset
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
|
|
|
|
1)
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# threshold
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
|
|
|
|
1)
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, None)], schema).dropna(thresh=2).count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# threshold and subset
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
|
|
|
|
1)
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# thresh should take precedence over how
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(u'Alice', 50, None)], schema).dropna(
|
|
|
|
how='any', thresh=2, subset=['name', 'age']).count(),
|
|
|
|
1)
|
|
|
|
|
|
|
|
def test_fillna(self):
|
|
|
|
schema = StructType([
|
|
|
|
StructField("name", StringType(), True),
|
|
|
|
StructField("age", IntegerType(), True),
|
|
|
|
StructField("height", DoubleType(), True)])
|
|
|
|
|
|
|
|
# fillna shouldn't change non-null values
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
|
2015-03-30 23:47:10 -04:00
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
|
|
|
|
# fillna with int
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
|
2015-03-30 23:47:10 -04:00
|
|
|
self.assertEqual(row.age, 50)
|
|
|
|
self.assertEqual(row.height, 50.0)
|
|
|
|
|
|
|
|
# fillna with double
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
|
2015-03-30 23:47:10 -04:00
|
|
|
self.assertEqual(row.age, 50)
|
|
|
|
self.assertEqual(row.height, 50.1)
|
|
|
|
|
|
|
|
# fillna with string
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first()
|
2015-03-30 23:47:10 -04:00
|
|
|
self.assertEqual(row.name, u"hello")
|
|
|
|
self.assertEqual(row.age, None)
|
|
|
|
|
|
|
|
# fillna with subset specified for numeric cols
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
|
|
|
|
self.assertEqual(row.name, None)
|
|
|
|
self.assertEqual(row.age, 50)
|
|
|
|
self.assertEqual(row.height, None)
|
|
|
|
|
|
|
|
# fillna with subset specified for string cols
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-03-30 23:47:10 -04:00
|
|
|
[(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
|
|
|
|
self.assertEqual(row.name, "haha")
|
|
|
|
self.assertEqual(row.age, None)
|
|
|
|
self.assertEqual(row.height, None)
|
|
|
|
|
2015-05-07 04:00:29 -04:00
|
|
|
def test_bitwise_operations(self):
|
|
|
|
from pyspark.sql import functions
|
|
|
|
row = Row(a=170, b=75)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-05-07 04:00:29 -04:00
|
|
|
result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(170 & 75, result['(a & b)'])
|
|
|
|
result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(170 | 75, result['(a | b)'])
|
|
|
|
result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(170 ^ 75, result['(a ^ b)'])
|
|
|
|
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(~75, result['~b'])
|
|
|
|
|
2015-07-25 03:34:59 -04:00
|
|
|
def test_expr(self):
|
|
|
|
from pyspark.sql import functions
|
|
|
|
row = Row(a="length string", b=75)
|
2016-05-11 14:24:16 -04:00
|
|
|
df = self.spark.createDataFrame([row])
|
2015-07-25 03:34:59 -04:00
|
|
|
result = df.select(functions.expr("length(a)")).collect()[0].asDict()
|
2015-11-10 14:06:29 -05:00
|
|
|
self.assertEqual(13, result["length(a)"])
|
2015-07-25 03:34:59 -04:00
|
|
|
|
2015-05-12 13:23:41 -04:00
|
|
|
def test_replace(self):
|
|
|
|
schema = StructType([
|
|
|
|
StructField("name", StringType(), True),
|
|
|
|
StructField("age", IntegerType(), True),
|
|
|
|
StructField("height", DoubleType(), True)])
|
|
|
|
|
|
|
|
# replace with int
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
|
2015-05-12 13:23:41 -04:00
|
|
|
self.assertEqual(row.age, 20)
|
|
|
|
self.assertEqual(row.height, 20.0)
|
|
|
|
|
|
|
|
# replace with double
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
|
|
|
|
self.assertEqual(row.age, 82)
|
|
|
|
self.assertEqual(row.height, 82.1)
|
|
|
|
|
|
|
|
# replace with string
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
|
|
|
|
self.assertEqual(row.name, u"Ann")
|
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
|
|
|
|
# replace with subset specified by a string of a column name w/ actual change
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
|
|
|
|
self.assertEqual(row.age, 20)
|
|
|
|
|
|
|
|
# replace with subset specified by a string of a column name w/o actual change
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
|
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
|
|
|
|
# replace with subset specified with one column replaced, another column not in subset
|
|
|
|
# stays unchanged.
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
|
|
|
|
self.assertEqual(row.name, u'Alice')
|
|
|
|
self.assertEqual(row.age, 20)
|
|
|
|
self.assertEqual(row.height, 10.0)
|
|
|
|
|
|
|
|
# replace with subset specified but no column will be replaced
|
2016-05-11 14:24:16 -04:00
|
|
|
row = self.spark.createDataFrame(
|
2015-05-12 13:23:41 -04:00
|
|
|
[(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
|
|
|
|
self.assertEqual(row.name, u'Alice')
|
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
self.assertEqual(row.height, None)
|
|
|
|
|
2017-04-05 14:47:40 -04:00
|
|
|
# replace with lists
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
|
|
|
|
self.assertTupleEqual(row, (u'Ann', 10, 80.1))
|
|
|
|
|
|
|
|
# replace with dict
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', 11, 80.1))
|
|
|
|
|
|
|
|
# test backward compatibility with dummy value
|
|
|
|
dummy_value = 1
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first()
|
|
|
|
self.assertTupleEqual(row, (u'Bob', 10, 80.1))
|
|
|
|
|
|
|
|
# test dict with mixed numerics
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', -10, 90.5))
|
|
|
|
|
|
|
|
# replace with tuples
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
|
|
|
|
self.assertTupleEqual(row, (u'Bob', 10, 80.1))
|
|
|
|
|
|
|
|
# replace multiple columns
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', 20, 90.0))
|
|
|
|
|
|
|
|
# test for mixed numerics
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', 20, 90.5))
|
|
|
|
|
|
|
|
row = self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
|
|
|
|
self.assertTupleEqual(row, (u'Alice', 20, 90.5))
|
|
|
|
|
|
|
|
# replace with boolean
|
|
|
|
row = (self
|
|
|
|
.spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
|
|
|
|
.selectExpr("name = 'Bob'", 'age <= 15')
|
|
|
|
.replace(False, True).first())
|
|
|
|
self.assertTupleEqual(row, (True, True))
|
|
|
|
|
|
|
|
# should fail if subset is not list, tuple or None
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()
|
|
|
|
|
|
|
|
# should fail if to_replace and value have different length
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()
|
|
|
|
|
|
|
|
# should fail when given an unexpected type
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
from datetime import datetime
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()
|
|
|
|
|
|
|
|
# should fail if provided mixed type replacements
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()
|
|
|
|
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
self.spark.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()
|
|
|
|
|
2015-06-30 19:17:46 -04:00
|
|
|
def test_capture_analysis_exception(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
|
2015-06-30 19:17:46 -04:00
|
|
|
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
|
2016-03-28 15:31:12 -04:00
|
|
|
|
|
|
|
def test_capture_parse_exception(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
|
2015-06-30 19:17:46 -04:00
|
|
|
|
2015-07-19 03:32:56 -04:00
|
|
|
def test_capture_illegalargument_exception(self):
|
|
|
|
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
|
2016-05-11 14:24:16 -04:00
|
|
|
lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
|
|
|
|
df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
|
2015-07-19 03:32:56 -04:00
|
|
|
self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
|
|
|
|
lambda: df.select(sha2(df.a, 1024)).collect())
|
2015-10-29 00:45:00 -04:00
|
|
|
try:
|
|
|
|
df.select(sha2(df.a, 1024)).collect()
|
|
|
|
except IllegalArgumentException as e:
|
|
|
|
self.assertRegexpMatches(e.desc, "1024 is not in the permitted values")
|
|
|
|
self.assertRegexpMatches(e.stackTrace,
|
|
|
|
"org.apache.spark.sql.functions")
|
2015-07-19 03:32:56 -04:00
|
|
|
|
2015-08-19 16:56:40 -04:00
|
|
|
def test_with_column_with_existing_name(self):
|
|
|
|
keys = self.df.withColumn("key", self.df.key).select("key").collect()
|
|
|
|
self.assertEqual([r.key for r in keys], list(range(100)))
|
|
|
|
|
2015-09-02 16:36:36 -04:00
|
|
|
# regression test for SPARK-10417
|
|
|
|
def test_column_iterator(self):
|
|
|
|
|
|
|
|
def foo():
|
|
|
|
for x in self.df.key:
|
|
|
|
break
|
|
|
|
|
|
|
|
self.assertRaises(TypeError, foo)
|
|
|
|
|
2015-09-22 02:36:41 -04:00
|
|
|
# add test for SPARK-10577 (test broadcast join hint)
|
|
|
|
def test_functions_broadcast(self):
|
|
|
|
from pyspark.sql.functions import broadcast
|
|
|
|
|
2016-05-11 14:24:16 -04:00
|
|
|
df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
|
|
|
|
df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
|
2015-09-22 02:36:41 -04:00
|
|
|
|
|
|
|
# equijoin - should be converted into broadcast join
|
|
|
|
plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
|
|
|
|
self.assertEqual(1, plan1.toString().count("BroadcastHashJoin"))
|
|
|
|
|
|
|
|
# no join key -- should not be a broadcast join
|
2016-10-14 21:24:47 -04:00
|
|
|
plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan()
|
2015-09-22 02:36:41 -04:00
|
|
|
self.assertEqual(0, plan2.toString().count("BroadcastHashJoin"))
|
|
|
|
|
|
|
|
# planner should not crash without a join
|
|
|
|
broadcast(df1)._jdf.queryExecution().executedPlan()
|
|
|
|
|
2016-03-08 17:00:03 -05:00
|
|
|
def test_toDF_with_schema_string(self):
|
|
|
|
data = [Row(key=i, value=str(i)) for i in range(100)]
|
|
|
|
rdd = self.sc.parallelize(data, 5)
|
|
|
|
|
|
|
|
df = rdd.toDF("key: int, value: string")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
|
|
|
|
self.assertEqual(df.collect(), data)
|
|
|
|
|
|
|
|
# different but compatible field types can be used.
|
|
|
|
df = rdd.toDF("key: string, value: string")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
|
|
|
|
self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
|
|
|
|
|
|
|
|
# field names can differ.
|
|
|
|
df = rdd.toDF(" a: int, b: string ")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
|
|
|
|
self.assertEqual(df.collect(), data)
|
|
|
|
|
|
|
|
# number of fields must match.
|
|
|
|
self.assertRaisesRegexp(Exception, "Length of object",
|
|
|
|
lambda: rdd.toDF("key: int").collect())
|
|
|
|
|
|
|
|
# a field type mismatch will cause an exception at runtime.
|
|
|
|
self.assertRaisesRegexp(Exception, "FloatType can not accept",
|
|
|
|
lambda: rdd.toDF("key: float, value: string").collect())
|
|
|
|
|
|
|
|
# flat schema values will be wrapped into a row.
|
|
|
|
df = rdd.map(lambda row: row.key).toDF("int")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<value:int>")
|
|
|
|
self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
|
|
|
|
|
|
|
|
# users can use DataType directly instead of data type string.
|
|
|
|
df = rdd.map(lambda row: row.key).toDF(IntegerType())
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<value:int>")
|
|
|
|
self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
|
|
|
|
|
2016-10-12 13:09:49 -04:00
|
|
|
# Regression test for invalid join methods when on is None (SPARK-14761)
|
|
|
|
def test_invalid_join_method(self):
|
|
|
|
df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
|
|
|
|
df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"])
|
|
|
|
self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type"))
|
|
|
|
|
2016-10-14 21:24:47 -04:00
|
|
|
# Cartesian products require cross join syntax
|
|
|
|
def test_require_cross(self):
|
|
|
|
from pyspark.sql.functions import broadcast
|
|
|
|
|
|
|
|
df1 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
|
|
|
|
df2 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
|
|
|
|
|
|
|
|
# joins without conditions require cross join syntax
|
|
|
|
self.assertRaises(AnalysisException, lambda: df1.join(df2).collect())
|
|
|
|
|
|
|
|
# works with crossJoin
|
|
|
|
self.assertEqual(1, df1.crossJoin(df2).count())
|
|
|
|
|
2016-04-29 19:41:13 -04:00
|
|
|
def test_conf(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 23:46:07 -04:00
|
|
|
spark.conf.set("bogo", "sipeo")
|
2016-05-11 14:24:16 -04:00
|
|
|
self.assertEqual(spark.conf.get("bogo"), "sipeo")
|
2016-04-29 23:46:07 -04:00
|
|
|
spark.conf.set("bogo", "ta")
|
2016-04-29 19:41:13 -04:00
|
|
|
self.assertEqual(spark.conf.get("bogo"), "ta")
|
|
|
|
self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
|
|
|
|
self.assertEqual(spark.conf.get("not.set", "ta"), "ta")
|
|
|
|
self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set"))
|
|
|
|
spark.conf.unset("bogo")
|
|
|
|
self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
|
|
|
|
|
|
|
|
def test_current_database(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 19:41:13 -04:00
|
|
|
spark.catalog._reset()
|
|
|
|
self.assertEquals(spark.catalog.currentDatabase(), "default")
|
|
|
|
spark.sql("CREATE DATABASE some_db")
|
|
|
|
spark.catalog.setCurrentDatabase("some_db")
|
|
|
|
self.assertEquals(spark.catalog.currentDatabase(), "some_db")
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
|
|
|
|
|
|
|
|
def test_list_databases(self):
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 19:41:13 -04:00
|
|
|
spark.catalog._reset()
|
|
|
|
databases = [db.name for db in spark.catalog.listDatabases()]
|
|
|
|
self.assertEquals(databases, ["default"])
|
|
|
|
spark.sql("CREATE DATABASE some_db")
|
|
|
|
databases = [db.name for db in spark.catalog.listDatabases()]
|
|
|
|
self.assertEquals(sorted(databases), ["default", "some_db"])
|
|
|
|
|
|
|
|
def test_list_tables(self):
|
|
|
|
from pyspark.sql.catalog import Table
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 19:41:13 -04:00
|
|
|
spark.catalog._reset()
|
|
|
|
spark.sql("CREATE DATABASE some_db")
|
|
|
|
self.assertEquals(spark.catalog.listTables(), [])
|
|
|
|
self.assertEquals(spark.catalog.listTables("some_db"), [])
|
2016-05-17 21:01:59 -04:00
|
|
|
spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
|
2017-01-22 23:37:37 -05:00
|
|
|
spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
|
|
|
|
spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
|
2016-04-29 19:41:13 -04:00
|
|
|
tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
|
|
|
|
tablesDefault = sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
|
|
|
|
tablesSomeDb = sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
|
|
|
|
self.assertEquals(tables, tablesDefault)
|
|
|
|
self.assertEquals(len(tables), 2)
|
|
|
|
self.assertEquals(len(tablesSomeDb), 2)
|
|
|
|
self.assertEquals(tables[0], Table(
|
|
|
|
name="tab1",
|
|
|
|
database="default",
|
|
|
|
description=None,
|
|
|
|
tableType="MANAGED",
|
|
|
|
isTemporary=False))
|
|
|
|
self.assertEquals(tables[1], Table(
|
|
|
|
name="temp_tab",
|
|
|
|
database=None,
|
|
|
|
description=None,
|
|
|
|
tableType="TEMPORARY",
|
|
|
|
isTemporary=True))
|
|
|
|
self.assertEquals(tablesSomeDb[0], Table(
|
|
|
|
name="tab2",
|
|
|
|
database="some_db",
|
|
|
|
description=None,
|
|
|
|
tableType="MANAGED",
|
|
|
|
isTemporary=False))
|
|
|
|
self.assertEquals(tablesSomeDb[1], Table(
|
|
|
|
name="temp_tab",
|
|
|
|
database=None,
|
|
|
|
description=None,
|
|
|
|
tableType="TEMPORARY",
|
|
|
|
isTemporary=True))
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.listTables("does_not_exist"))
|
|
|
|
|
|
|
|
def test_list_functions(self):
|
|
|
|
from pyspark.sql.catalog import Function
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 19:41:13 -04:00
|
|
|
spark.catalog._reset()
|
|
|
|
spark.sql("CREATE DATABASE some_db")
|
|
|
|
functions = dict((f.name, f) for f in spark.catalog.listFunctions())
|
|
|
|
functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
|
2016-06-27 14:50:34 -04:00
|
|
|
self.assertTrue(len(functions) > 200)
|
|
|
|
self.assertTrue("+" in functions)
|
|
|
|
self.assertTrue("like" in functions)
|
|
|
|
self.assertTrue("month" in functions)
|
2017-02-07 09:50:30 -05:00
|
|
|
self.assertTrue("to_date" in functions)
|
|
|
|
self.assertTrue("to_timestamp" in functions)
|
2016-06-27 14:50:34 -04:00
|
|
|
self.assertTrue("to_unix_timestamp" in functions)
|
|
|
|
self.assertTrue("current_database" in functions)
|
|
|
|
self.assertEquals(functions["+"], Function(
|
|
|
|
name="+",
|
|
|
|
description=None,
|
|
|
|
className="org.apache.spark.sql.catalyst.expressions.Add",
|
|
|
|
isTemporary=True))
|
2016-04-29 19:41:13 -04:00
|
|
|
self.assertEquals(functions, functionsDefault)
|
|
|
|
spark.catalog.registerFunction("temp_func", lambda x: str(x))
|
|
|
|
spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
|
|
|
|
spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
|
|
|
|
newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
|
|
|
|
newFunctionsSomeDb = dict((f.name, f) for f in spark.catalog.listFunctions("some_db"))
|
|
|
|
self.assertTrue(set(functions).issubset(set(newFunctions)))
|
|
|
|
self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
|
|
|
|
self.assertTrue("temp_func" in newFunctions)
|
|
|
|
self.assertTrue("func1" in newFunctions)
|
|
|
|
self.assertTrue("func2" not in newFunctions)
|
|
|
|
self.assertTrue("temp_func" in newFunctionsSomeDb)
|
|
|
|
self.assertTrue("func1" not in newFunctionsSomeDb)
|
|
|
|
self.assertTrue("func2" in newFunctionsSomeDb)
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.listFunctions("does_not_exist"))
|
|
|
|
|
|
|
|
def test_list_columns(self):
|
|
|
|
from pyspark.sql.catalog import Column
|
2016-05-11 14:24:16 -04:00
|
|
|
spark = self.spark
|
2016-04-29 19:41:13 -04:00
|
|
|
spark.catalog._reset()
|
|
|
|
spark.sql("CREATE DATABASE some_db")
|
2017-01-22 23:37:37 -05:00
|
|
|
spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
|
|
|
|
spark.sql("CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet")
|
2016-04-29 19:41:13 -04:00
|
|
|
columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
|
|
|
|
columnsDefault = sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
|
|
|
|
self.assertEquals(columns, columnsDefault)
|
|
|
|
self.assertEquals(len(columns), 2)
|
|
|
|
self.assertEquals(columns[0], Column(
|
|
|
|
name="age",
|
|
|
|
description=None,
|
|
|
|
dataType="int",
|
|
|
|
nullable=True,
|
|
|
|
isPartition=False,
|
|
|
|
isBucket=False))
|
|
|
|
self.assertEquals(columns[1], Column(
|
|
|
|
name="name",
|
|
|
|
description=None,
|
|
|
|
dataType="string",
|
|
|
|
nullable=True,
|
|
|
|
isPartition=False,
|
|
|
|
isBucket=False))
|
|
|
|
columns2 = sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
|
|
|
|
self.assertEquals(len(columns2), 2)
|
|
|
|
self.assertEquals(columns2[0], Column(
|
|
|
|
name="nickname",
|
|
|
|
description=None,
|
|
|
|
dataType="string",
|
|
|
|
nullable=True,
|
|
|
|
isPartition=False,
|
|
|
|
isBucket=False))
|
|
|
|
self.assertEquals(columns2[1], Column(
|
|
|
|
name="tolerance",
|
|
|
|
description=None,
|
|
|
|
dataType="float",
|
|
|
|
nullable=True,
|
|
|
|
isPartition=False,
|
|
|
|
isBucket=False))
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"tab2",
|
|
|
|
lambda: spark.catalog.listColumns("tab2"))
|
|
|
|
self.assertRaisesRegexp(
|
|
|
|
AnalysisException,
|
|
|
|
"does_not_exist",
|
|
|
|
lambda: spark.catalog.listColumns("does_not_exist"))

    def test_cache(self):
        spark = self.spark
        spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1")
        spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2")
        self.assertFalse(spark.catalog.isCached("tab1"))
        self.assertFalse(spark.catalog.isCached("tab2"))
        spark.catalog.cacheTable("tab1")
        self.assertTrue(spark.catalog.isCached("tab1"))
        self.assertFalse(spark.catalog.isCached("tab2"))
        spark.catalog.cacheTable("tab2")
        spark.catalog.uncacheTable("tab1")
        self.assertFalse(spark.catalog.isCached("tab1"))
        self.assertTrue(spark.catalog.isCached("tab2"))
        spark.catalog.clearCache()
        self.assertFalse(spark.catalog.isCached("tab1"))
        self.assertFalse(spark.catalog.isCached("tab2"))
        self.assertRaisesRegexp(
            AnalysisException,
            "does_not_exist",
            lambda: spark.catalog.isCached("does_not_exist"))
        self.assertRaisesRegexp(
            AnalysisException,
            "does_not_exist",
            lambda: spark.catalog.cacheTable("does_not_exist"))
        self.assertRaisesRegexp(
            AnalysisException,
            "does_not_exist",
            lambda: spark.catalog.uncacheTable("does_not_exist"))

    def test_read_text_file_list(self):
        df = self.spark.read.text(['python/test_support/sql/text-test.txt',
                                   'python/test_support/sql/text-test.txt'])
        count = df.count()
        self.assertEquals(count, 4)

    def test_BinaryType_serialization(self):
        # Pyrolite version <= 4.9 could not serialize BinaryType with Python3 SPARK-17808
        schema = StructType([StructField('mybytes', BinaryType())])
        data = [[bytearray(b'here is my data')],
                [bytearray(b'and here is some more')]]
        df = self.spark.createDataFrame(data, schema=schema)
        df.collect()
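        # Illustrative note, not part of the original assertions: BinaryType values come
        # back to Python as bytearray, so a stronger round-trip check could look like the
        # sketch below (`rows` is a hypothetical local name):
        #
        #   rows = df.collect()
        #   assert rows[0].mybytes == bytearray(b'here is my data')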


class HiveSparkSubmitTests(SparkSubmitTests):

    def test_hivecontext(self):
        # This test checks that HiveContext is using the Hive metastore (SPARK-16224).
        # It sets a metastore URL and checks whether the Hive metastore creates a Derby
        # directory there; if that directory exists, HiveContext is backed by the Hive
        # metastore.
        metastore_path = os.path.join(tempfile.mkdtemp(), "spark16224_metastore_db")
        metastore_URL = "jdbc:derby:;databaseName=" + metastore_path + ";create=true"
        hive_site_dir = os.path.join(self.programDir, "conf")
        hive_site_file = self.createTempFile("hive-site.xml", ("""
            |<configuration>
            | <property>
            |  <name>javax.jdo.option.ConnectionURL</name>
            |  <value>%s</value>
            | </property>
            |</configuration>
            """ % metastore_URL).lstrip(), "conf")
        script = self.createTempFile("test.py", """
            |import os
            |
            |from pyspark.conf import SparkConf
            |from pyspark.context import SparkContext
            |from pyspark.sql import HiveContext
            |
            |conf = SparkConf()
            |sc = SparkContext(conf=conf)
            |hive_context = HiveContext(sc)
            |print(hive_context.sql("show databases").collect())
            """)
        proc = subprocess.Popen(
            [self.sparkSubmit, "--master", "local-cluster[1,1,1024]",
             "--driver-class-path", hive_site_dir, script],
            stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("default", out.decode('utf-8'))
        self.assertTrue(os.path.exists(metastore_path))
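        # Illustrative alternative, not exercised by this test: the same connection URL
        # can usually be supplied without a hive-site.xml by passing a Hadoop property on
        # the submit command line, e.g.
        #   --conf spark.hadoop.javax.jdo.option.ConnectionURL=<metastore_URL>
        # This sketch assumes the standard spark.hadoop.* pass-through of Hadoop/Hive
        # settings; the hive-site.xml approach above is what the test actually verifies.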


class SQLTests2(ReusedPySparkTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.spark = SparkSession(cls.sc)

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        cls.spark.stop()

    # This test cannot live in SQLTests because it stops the class-level SparkContext,
    # which would cause the other tests in that suite to fail.
    def test_sparksession_with_stopped_sparkcontext(self):
        self.sc.stop()
        sc = SparkContext('local[4]', self.sc.appName)
        spark = SparkSession.builder.getOrCreate()
        df = spark.createDataFrame([(1, 2)], ["c", "c"])
        df.collect()


class UDFInitializationTests(unittest.TestCase):
    def tearDown(self):
        if SparkSession._instantiatedSession is not None:
            SparkSession._instantiatedSession.stop()

        if SparkContext._active_spark_context is not None:
            SparkContext._active_spark_context.stop()

    def test_udf_init_shouldnt_initalize_context(self):
        from pyspark.sql.functions import UserDefinedFunction

        UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(
            SparkContext._active_spark_context,
            "SparkContext shouldn't be initialized when UserDefinedFunction is created."
        )
        self.assertIsNone(
            SparkSession._instantiatedSession,
            "SparkSession shouldn't be initialized when UserDefinedFunction is created."
        )


class HiveContextSQLTests(ReusedPySparkTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        try:
            cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
        except py4j.protocol.Py4JError:
            cls.tearDownClass()
            raise unittest.SkipTest("Hive is not available")
        except TypeError:
            cls.tearDownClass()
            raise unittest.SkipTest("Hive is not available")
        os.unlink(cls.tempdir.name)
        cls.spark = HiveContext._createForTesting(cls.sc)
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        cls.df = cls.sc.parallelize(cls.testData).toDF()

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)

    def test_save_and_load_table(self):
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
        actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json")
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        self.spark.sql("DROP TABLE externalJsonTable")

        df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
        schema = StructType([StructField("value", StringType(), True)])
        actual = self.spark.createExternalTable("externalJsonTable", source="json",
                                                schema=schema, path=tmpPath,
                                                noUse="this option will not be used")
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertEqual(sorted(df.select("value").collect()),
                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
        self.spark.sql("DROP TABLE savedJsonTable")
        self.spark.sql("DROP TABLE externalJsonTable")

        defaultDataSourceName = self.spark.getConf("spark.sql.sources.default",
                                                   "org.apache.spark.sql.parquet")
        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
        actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath)
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        self.spark.sql("DROP TABLE savedJsonTable")
        self.spark.sql("DROP TABLE externalJsonTable")
        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

        shutil.rmtree(tmpPath)

    def test_window_functions(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        w = Window.partitionBy("value").orderBy("key")
        from pyspark.sql import functions as F
        sel = df.select(df.value, df.key,
                        F.max("key").over(w.rowsBetween(0, 1)),
                        F.min("key").over(w.rowsBetween(0, 1)),
                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
                        F.row_number().over(w),
                        F.rank().over(w),
                        F.dense_rank().over(w),
                        F.ntile(2).over(w))
        rs = sorted(sel.collect())
        expected = [
            ("1", 1, 1, 1, 1, 1, 1, 1, 1),
            ("2", 1, 1, 1, 3, 1, 1, 1, 1),
            ("2", 1, 2, 1, 3, 2, 1, 1, 1),
            ("2", 2, 2, 2, 3, 3, 3, 2, 2)
        ]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])
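        # Illustrative mapping, not an original assertion: the max/min expressions above
        # use w.rowsBetween(0, 1), which corresponds roughly to the SQL clause
        #   OVER (PARTITION BY value ORDER BY key
        #         ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING)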

    def test_window_functions_without_partitionBy(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        w = Window.orderBy("key", df.value)
        from pyspark.sql import functions as F
        sel = df.select(df.value, df.key,
                        F.max("key").over(w.rowsBetween(0, 1)),
                        F.min("key").over(w.rowsBetween(0, 1)),
                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
                        F.row_number().over(w),
                        F.rank().over(w),
                        F.dense_rank().over(w),
                        F.ntile(2).over(w))
        rs = sorted(sel.collect())
        expected = [
            ("1", 1, 1, 1, 4, 1, 1, 1, 1),
            ("2", 1, 1, 1, 4, 2, 2, 2, 1),
            ("2", 1, 2, 1, 4, 3, 2, 2, 2),
            ("2", 2, 2, 2, 4, 4, 4, 3, 2)
        ]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

    def test_window_functions_cumulative_sum(self):
        df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"])
        from pyspark.sql import functions as F

        # Test cumulative sum
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0)))
        rs = sorted(sel.collect())
        expected = [("one", 1), ("two", 3)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

        # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0)))
        rs = sorted(sel.collect())
        expected = [("one", 1), ("two", 3)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

        # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow
        frame_end = Window.unboundedFollowing + 1
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end)))
        rs = sorted(sel.collect())
        expected = [("one", 3), ("two", 2)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])
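        # Background on the boundary handling exercised above (SPARK-17845): any start
        # value at or below the JVM's Long.MinValue is treated as Window.unboundedPreceding
        # and any end value at or above Long.MaxValue as Window.unboundedFollowing, so the
        # idiomatic spelling of a running total is
        #   Window.rowsBetween(Window.unboundedPreceding, Window.currentRow)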

    def test_collect_functions(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql import functions

        self.assertEqual(
            sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r),
            [1, 2])
        self.assertEqual(
            sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r),
            [1, 1, 1, 2])
        self.assertEqual(
            sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r),
            ["1", "2"])
        self.assertEqual(
            sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
            ["1", "2", "2", "2"])

    def test_limit_and_take(self):
        df = self.spark.range(1, 1000, numPartitions=10)

        def assert_runs_only_one_job_stage_and_task(job_group_name, f):
            tracker = self.sc.statusTracker()
            self.sc.setJobGroup(job_group_name, description="")
            f()
            jobs = tracker.getJobIdsForGroup(job_group_name)
            self.assertEqual(1, len(jobs))
            stages = tracker.getJobInfo(jobs[0]).stageIds
            self.assertEqual(1, len(stages))
            self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks)

        # Regression test for SPARK-10731: take should delegate to Scala implementation
        assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1))
        # Regression test for SPARK-17514: limit(n).collect() should perform the same as take(n)
        assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect())
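        # Background (from SPARK-17514): df.take(1) always ran a single-stage job that
        # computed only the partitions it needed, while df.limit(1).collect() used to plan
        # a global limit in the middle of the query and touch every partition. collect()
        # now delegates to the Scala implementation, so both paths are expected to stay at
        # one job, one stage, and one task, as asserted above.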

    def test_datetime_functions(self):
        from pyspark.sql import functions
        from datetime import date, datetime
        df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
        parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
        self.assertEquals(date(2017, 1, 22), parse_result['to_date(dateCol)'])
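        # Illustrative equivalent, not asserted here: the same conversion can be expressed
        # in SQL, e.g. spark.sql("SELECT to_date('2017-01-22')"), which is expected to
        # return a DateType value for 2017-01-22.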

    @unittest.skipIf(sys.version_info < (3, 3), "unittest.mock requires Python 3.3+")
    def test_unbounded_frames(self):
        from unittest.mock import patch
        from pyspark.sql import functions as F
        from pyspark.sql import window
        import importlib

        df = self.spark.range(0, 3)

        def rows_frame_match():
            return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select(
                F.count("*").over(window.Window.rowsBetween(-sys.maxsize, sys.maxsize))
            ).columns[0]

        def range_frame_match():
            return "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select(
                F.count("*").over(window.Window.rangeBetween(-sys.maxsize, sys.maxsize))
            ).columns[0]

        with patch("sys.maxsize", 2 ** 31 - 1):
            importlib.reload(window)
            self.assertTrue(rows_frame_match())
            self.assertTrue(range_frame_match())

        with patch("sys.maxsize", 2 ** 63 - 1):
            importlib.reload(window)
            self.assertTrue(rows_frame_match())
            self.assertTrue(range_frame_match())

        with patch("sys.maxsize", 2 ** 127 - 1):
            importlib.reload(window)
            self.assertTrue(rows_frame_match())
            self.assertTrue(range_frame_match())

        # Reload once more outside the patches so later tests see pyspark.sql.window
        # built against the real sys.maxsize.
        importlib.reload(window)


if __name__ == "__main__":
    from pyspark.sql.tests import *
    if xmlrunner:
        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
    else:
        unittest.main()