2015-07-01 19:43:18 -04:00
|
|
|
# -*- encoding: utf-8 -*-
|
2015-02-03 19:01:56 -05:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
|
|
|
"""
|
|
|
|
Unit tests for pyspark.sql; additional tests are implemented as doctests in
|
|
|
|
individual modules.
|
|
|
|
"""
|
|
|
|
import os
|
|
|
|
import sys
|
|
|
|
import pydoc
|
|
|
|
import shutil
|
|
|
|
import tempfile
|
2015-02-27 23:07:17 -05:00
|
|
|
import pickle
|
2015-04-01 20:23:57 -04:00
|
|
|
import functools
|
2015-06-11 04:00:41 -04:00
|
|
|
import time
|
2015-04-21 03:08:18 -04:00
|
|
|
import datetime
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2015-02-17 18:44:37 -05:00
|
|
|
import py4j
|
2015-10-22 18:27:11 -04:00
|
|
|
try:
|
|
|
|
import xmlrunner
|
|
|
|
except ImportError:
|
|
|
|
xmlrunner = None
|
2015-02-17 18:44:37 -05:00
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
if sys.version_info[:2] <= (2, 6):
|
|
|
|
try:
|
|
|
|
import unittest2 as unittest
|
|
|
|
except ImportError:
|
|
|
|
sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
|
|
|
|
sys.exit(1)
|
|
|
|
else:
|
|
|
|
import unittest
|
|
|
|
|
2015-02-24 23:51:55 -05:00
|
|
|
from pyspark.sql import SQLContext, HiveContext, Column, Row
|
|
|
|
from pyspark.sql.types import *
|
|
|
|
from pyspark.sql.types import UserDefinedType, _infer_type
|
2015-02-03 19:01:56 -05:00
|
|
|
from pyspark.tests import ReusedPySparkTestCase
|
2015-07-19 03:32:56 -04:00
|
|
|
from pyspark.sql.functions import UserDefinedFunction, sha2
|
2015-05-23 11:30:05 -04:00
|
|
|
from pyspark.sql.window import Window
|
2016-03-28 15:31:12 -04:00
|
|
|
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
|
2015-02-03 19:01:56 -05:00
|
|
|
|
|
|
|
|
[SPARK-10162] [SQL] Fix the timezone omitting for PySpark Dataframe filter function
This PR addresses [SPARK-10162](https://issues.apache.org/jira/browse/SPARK-10162)
The issue is with DataFrame filter() function, if datetime.datetime is passed to it:
* Timezone information of this datetime is ignored
* This datetime is assumed to be in local timezone, which depends on the OS timezone setting
Fix includes both code change and regression test. Problem reproduction code on master:
```python
import pytz
from datetime import datetime
from pyspark.sql import *
from pyspark.sql.types import *
sqc = SQLContext(sc)
df = sqc.createDataFrame([], StructType([StructField("dt", TimestampType())]))
m1 = pytz.timezone('UTC')
m2 = pytz.timezone('Etc/GMT+3')
df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
```
It gives the same timestamp ignoring time zone:
```
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
Filter (dt#0 > 946713600000000)
Scan PhysicalRDD[dt#0]
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
Filter (dt#0 > 946713600000000)
Scan PhysicalRDD[dt#0]
```
After the fix:
```
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
Filter (dt#0 > 946684800000000)
Scan PhysicalRDD[dt#0]
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
Filter (dt#0 > 946695600000000)
Scan PhysicalRDD[dt#0]
```
PR [8536](https://github.com/apache/spark/pull/8536) was occasionally closed by me dropping the repo
Author: 0x0FFF <programmerag@gmail.com>
Closes #8555 from 0x0FFF/SPARK-10162.
2015-09-01 17:34:59 -04:00
|
|
|
class UTCOffsetTimezone(datetime.tzinfo):
|
|
|
|
"""
|
|
|
|
Specifies timezone in UTC offset
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, offset=0):
|
|
|
|
self.ZERO = datetime.timedelta(hours=offset)
|
2015-06-11 04:00:41 -04:00
|
|
|
|
|
|
|
def utcoffset(self, dt):
|
|
|
|
return self.ZERO
|
|
|
|
|
|
|
|
def dst(self, dt):
|
|
|
|
return self.ZERO
|
|
|
|
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
class ExamplePointUDT(UserDefinedType):
|
|
|
|
"""
|
|
|
|
User-defined type (UDT) for ExamplePoint.
|
|
|
|
"""
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def sqlType(self):
|
|
|
|
return ArrayType(DoubleType(), False)
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def module(cls):
|
2015-07-30 01:30:49 -04:00
|
|
|
return 'pyspark.sql.tests'
|
2015-02-03 19:01:56 -05:00
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def scalaUDT(cls):
|
|
|
|
return 'org.apache.spark.sql.test.ExamplePointUDT'
|
|
|
|
|
|
|
|
def serialize(self, obj):
|
|
|
|
return [obj.x, obj.y]
|
|
|
|
|
|
|
|
def deserialize(self, datum):
|
|
|
|
return ExamplePoint(datum[0], datum[1])
|
|
|
|
|
|
|
|
|
|
|
|
class ExamplePoint:
|
|
|
|
"""
|
|
|
|
An example class to demonstrate UDT in Scala, Java, and Python.
|
|
|
|
"""
|
|
|
|
|
|
|
|
__UDT__ = ExamplePointUDT()
|
|
|
|
|
|
|
|
def __init__(self, x, y):
|
|
|
|
self.x = x
|
|
|
|
self.y = y
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
return "ExamplePoint(%s,%s)" % (self.x, self.y)
|
|
|
|
|
|
|
|
def __str__(self):
|
|
|
|
return "(%s,%s)" % (self.x, self.y)
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
2015-07-30 01:30:49 -04:00
|
|
|
return isinstance(other, self.__class__) and \
|
2015-02-03 19:01:56 -05:00
|
|
|
other.x == self.x and other.y == self.y
|
|
|
|
|
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
class PythonOnlyUDT(UserDefinedType):
|
|
|
|
"""
|
|
|
|
User-defined type (UDT) for ExamplePoint.
|
|
|
|
"""
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def sqlType(self):
|
|
|
|
return ArrayType(DoubleType(), False)
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def module(cls):
|
|
|
|
return '__main__'
|
|
|
|
|
|
|
|
def serialize(self, obj):
|
|
|
|
return [obj.x, obj.y]
|
|
|
|
|
|
|
|
def deserialize(self, datum):
|
|
|
|
return PythonOnlyPoint(datum[0], datum[1])
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def foo():
|
|
|
|
pass
|
|
|
|
|
|
|
|
@property
|
|
|
|
def props(self):
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
|
|
class PythonOnlyPoint(ExamplePoint):
|
|
|
|
"""
|
|
|
|
An example class to demonstrate UDT in only Python
|
|
|
|
"""
|
|
|
|
__UDT__ = PythonOnlyUDT()
|
|
|
|
|
|
|
|
|
2015-08-26 19:04:44 -04:00
|
|
|
class MyObject(object):
|
|
|
|
def __init__(self, key, value):
|
|
|
|
self.key = key
|
|
|
|
self.value = value
|
|
|
|
|
|
|
|
|
2015-02-27 23:07:17 -05:00
|
|
|
class DataTypeTests(unittest.TestCase):
|
|
|
|
# regression test for SPARK-6055
|
|
|
|
def test_data_type_eq(self):
|
|
|
|
lt = LongType()
|
|
|
|
lt2 = pickle.loads(pickle.dumps(LongType()))
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(lt, lt2)
|
2015-02-27 23:07:17 -05:00
|
|
|
|
2015-05-31 22:55:57 -04:00
|
|
|
# regression test for SPARK-7978
|
|
|
|
def test_decimal_type(self):
|
|
|
|
t1 = DecimalType()
|
|
|
|
t2 = DecimalType(10, 2)
|
|
|
|
self.assertTrue(t2 is not t1)
|
|
|
|
self.assertNotEqual(t1, t2)
|
|
|
|
t3 = DecimalType(8)
|
|
|
|
self.assertNotEqual(t2, t3)
|
|
|
|
|
2015-09-01 17:58:49 -04:00
|
|
|
# regression test for SPARK-10392
|
|
|
|
def test_datetype_equal_zero(self):
|
|
|
|
dt = DateType()
|
|
|
|
self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))
|
|
|
|
|
2015-02-27 23:07:17 -05:00
|
|
|
|
2015-10-19 19:18:20 -04:00
|
|
|
class SQLContextTests(ReusedPySparkTestCase):
|
|
|
|
def test_get_or_create(self):
|
|
|
|
sqlCtx = SQLContext.getOrCreate(self.sc)
|
|
|
|
self.assertTrue(SQLContext.getOrCreate(self.sc) is sqlCtx)
|
|
|
|
|
|
|
|
def test_new_session(self):
|
|
|
|
sqlCtx = SQLContext.getOrCreate(self.sc)
|
|
|
|
sqlCtx.setConf("test_key", "a")
|
|
|
|
sqlCtx2 = sqlCtx.newSession()
|
|
|
|
sqlCtx2.setConf("test_key", "b")
|
|
|
|
self.assertEqual(sqlCtx.getConf("test_key", ""), "a")
|
|
|
|
self.assertEqual(sqlCtx2.getConf("test_key", ""), "b")
|
|
|
|
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
class SQLTests(ReusedPySparkTestCase):
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def setUpClass(cls):
|
|
|
|
ReusedPySparkTestCase.setUpClass()
|
|
|
|
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
|
|
|
|
os.unlink(cls.tempdir.name)
|
|
|
|
cls.sqlCtx = SQLContext(cls.sc)
|
|
|
|
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
|
2015-04-21 20:49:55 -04:00
|
|
|
rdd = cls.sc.parallelize(cls.testData, 2)
|
2015-02-14 02:03:22 -05:00
|
|
|
cls.df = rdd.toDF()
|
2015-02-03 19:01:56 -05:00
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def tearDownClass(cls):
|
|
|
|
ReusedPySparkTestCase.tearDownClass()
|
|
|
|
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
|
|
|
|
|
2015-08-08 11:38:18 -04:00
|
|
|
def test_row_should_be_read_only(self):
|
|
|
|
row = Row(a=1, b=2)
|
|
|
|
self.assertEqual(1, row.a)
|
|
|
|
|
|
|
|
def foo():
|
|
|
|
row.a = 3
|
|
|
|
self.assertRaises(Exception, foo)
|
|
|
|
|
|
|
|
row2 = self.sqlCtx.range(10).first()
|
|
|
|
self.assertEqual(0, row2.id)
|
|
|
|
|
|
|
|
def foo2():
|
|
|
|
row2.id = 2
|
|
|
|
self.assertRaises(Exception, foo2)
|
|
|
|
|
2015-05-19 00:43:12 -04:00
|
|
|
def test_range(self):
|
|
|
|
self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
|
|
|
|
self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
|
|
|
|
self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
|
2015-06-03 14:28:18 -04:00
|
|
|
self.assertEqual(self.sqlCtx.range(-2).count(), 0)
|
|
|
|
self.assertEqual(self.sqlCtx.range(3).count(), 3)
|
2015-05-19 00:43:12 -04:00
|
|
|
|
2015-07-09 17:43:38 -04:00
|
|
|
def test_duplicated_column_names(self):
|
|
|
|
df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"])
|
|
|
|
row = df.select('*').first()
|
|
|
|
self.assertEqual(1, row[0])
|
|
|
|
self.assertEqual(2, row[1])
|
|
|
|
self.assertEqual("Row(c=1, c=2)", str(row))
|
|
|
|
# Cannot access columns
|
|
|
|
self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
|
|
|
|
self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
|
|
|
|
self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
|
|
|
|
|
2015-05-14 22:49:44 -04:00
|
|
|
def test_explode(self):
|
|
|
|
from pyspark.sql.functions import explode
|
|
|
|
d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
|
|
|
|
rdd = self.sc.parallelize(d)
|
|
|
|
data = self.sqlCtx.createDataFrame(rdd)
|
|
|
|
|
|
|
|
result = data.select(explode(data.intlist).alias("a")).select("a").collect()
|
|
|
|
self.assertEqual(result[0][0], 1)
|
|
|
|
self.assertEqual(result[1][0], 2)
|
|
|
|
self.assertEqual(result[2][0], 3)
|
|
|
|
|
|
|
|
result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
|
|
|
|
self.assertEqual(result[0][0], "a")
|
|
|
|
self.assertEqual(result[0][1], "b")
|
|
|
|
|
2015-06-23 18:51:16 -04:00
|
|
|
def test_and_in_expression(self):
|
|
|
|
self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
|
|
|
|
self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
|
|
|
|
self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count())
|
|
|
|
self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2")
|
|
|
|
self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count())
|
|
|
|
self.assertRaises(ValueError, lambda: not self.df.key == 1)
|
|
|
|
|
2015-04-01 20:23:57 -04:00
|
|
|
def test_udf_with_callable(self):
|
|
|
|
d = [Row(number=i, squared=i**2) for i in range(10)]
|
|
|
|
rdd = self.sc.parallelize(d)
|
|
|
|
data = self.sqlCtx.createDataFrame(rdd)
|
|
|
|
|
|
|
|
class PlusFour:
|
|
|
|
def __call__(self, col):
|
|
|
|
if col is not None:
|
|
|
|
return col + 4
|
|
|
|
|
|
|
|
call = PlusFour()
|
|
|
|
pudf = UserDefinedFunction(call, LongType())
|
|
|
|
res = data.select(pudf(data['number']).alias('plus_four'))
|
|
|
|
self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
|
|
|
|
|
|
|
|
def test_udf_with_partial_function(self):
|
|
|
|
d = [Row(number=i, squared=i**2) for i in range(10)]
|
|
|
|
rdd = self.sc.parallelize(d)
|
|
|
|
data = self.sqlCtx.createDataFrame(rdd)
|
|
|
|
|
|
|
|
def some_func(col, param):
|
|
|
|
if col is not None:
|
|
|
|
return col + param
|
|
|
|
|
|
|
|
pfunc = functools.partial(some_func, param=4)
|
|
|
|
pudf = UserDefinedFunction(pfunc, LongType())
|
|
|
|
res = data.select(pudf(data['number']).alias('plus_four'))
|
|
|
|
self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_udf(self):
|
|
|
|
self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
|
|
|
|
[row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
|
|
|
|
self.assertEqual(row[0], 5)
|
|
|
|
|
|
|
|
def test_udf2(self):
|
|
|
|
self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
|
2015-02-10 22:40:12 -05:00
|
|
|
self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
|
2015-02-03 19:01:56 -05:00
|
|
|
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
|
|
|
|
self.assertEqual(4, res[0])
|
|
|
|
|
[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch
## What changes were proposed in this pull request?
This PR support multiple Python UDFs within single batch, also improve the performance.
```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$
== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
+- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- OneRowRelation$
== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
+- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- OneRowRelation$
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
: +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
+- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- Scan OneRowRelation[]
```
## How was this patch tested?
Added new tests.
Using the following script to benchmark 1, 2 and 3 udfs,
```
df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())
print df.select(double(df.id)).count()
print df.select(double(df.id), double(df.id + 1)).count()
print df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count()
```
Here is the results:
N | Before | After | speed up
---- |------------ | -------------|------
1 | 22 s | 7 s | 3.1X
2 | 38 s | 13 s | 2.9X
3 | 58 s | 16 s | 3.6X
This benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python before before this patch, 4 process after this patch. After this patch, it will use less memory for multiple UDFs than before (less buffering).
Author: Davies Liu <davies@databricks.com>
Closes #12057 from davies/multi_udfs.
2016-03-31 19:40:20 -04:00
|
|
|
def test_chained_udf(self):
|
[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?
This PR brings the support for chained Python UDFs, for example
```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```
Also directly chained unary Python UDFs are put in single batch of Python UDFs, others may require multiple batches.
For example,
```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#10 AS double(double(1))#9]
: +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
+- Scan OneRowRelation[]
>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
: +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
+- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
+- !BatchPythonEvaluation double(1), [pythonUDF#17]
+- Scan OneRowRelation[]
```
TODO: will support multiple unrelated Python UDFs in one batch (another PR).
## How was this patch tested?
Added new unit tests for chained UDFs.
Author: Davies Liu <davies@databricks.com>
Closes #12014 from davies/py_udfs.
2016-03-29 18:06:29 -04:00
|
|
|
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
|
|
|
|
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
|
|
|
|
self.assertEqual(row[0], 2)
|
|
|
|
[row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
|
|
|
|
self.assertEqual(row[0], 4)
|
|
|
|
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
|
|
|
|
self.assertEqual(row[0], 6)
|
|
|
|
|
[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch
## What changes were proposed in this pull request?
This PR support multiple Python UDFs within single batch, also improve the performance.
```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$
== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
+- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- OneRowRelation$
== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
+- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- OneRowRelation$
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
: +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
+- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- Scan OneRowRelation[]
```
## How was this patch tested?
Added new tests.
Using the following script to benchmark 1, 2 and 3 udfs,
```
df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())
print df.select(double(df.id)).count()
print df.select(double(df.id), double(df.id + 1)).count()
print df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count()
```
Here is the results:
N | Before | After | speed up
---- |------------ | -------------|------
1 | 22 s | 7 s | 3.1X
2 | 38 s | 13 s | 2.9X
3 | 58 s | 16 s | 3.6X
This benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python before before this patch, 4 process after this patch. After this patch, it will use less memory for multiple UDFs than before (less buffering).
Author: Davies Liu <davies@databricks.com>
Closes #12057 from davies/multi_udfs.
2016-03-31 19:40:20 -04:00
|
|
|
def test_multiple_udfs(self):
|
|
|
|
self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
|
|
|
|
[row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
|
|
|
|
self.assertEqual(tuple(row), (2, 4))
|
|
|
|
[row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
|
|
|
|
self.assertEqual(tuple(row), (4, 12))
|
|
|
|
self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
|
|
|
|
[row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
|
|
|
|
self.assertEqual(tuple(row), (6, 5))
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_udf_with_array_type(self):
|
2015-04-16 19:20:57 -04:00
|
|
|
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
|
2015-02-03 19:01:56 -05:00
|
|
|
rdd = self.sc.parallelize(d)
|
2015-02-10 22:40:12 -05:00
|
|
|
self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
|
2015-02-03 19:01:56 -05:00
|
|
|
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
|
|
|
|
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
|
|
|
|
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
|
2015-04-16 19:20:57 -04:00
|
|
|
self.assertEqual(list(range(3)), l1)
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, l2)
|
|
|
|
|
|
|
|
def test_broadcast_in_udf(self):
|
|
|
|
bar = {"a": "aa", "b": "bb", "c": "abc"}
|
|
|
|
foo = self.sc.broadcast(bar)
|
|
|
|
self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
|
|
|
|
[res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
|
|
|
|
self.assertEqual("abc", res[0])
|
|
|
|
[res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
|
|
|
|
self.assertEqual("", res[0])
|
|
|
|
|
2016-04-04 13:56:26 -04:00
|
|
|
def test_udf_with_aggregate_function(self):
|
|
|
|
df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
|
|
|
|
from pyspark.sql.functions import udf, col
|
|
|
|
from pyspark.sql.types import BooleanType
|
|
|
|
|
|
|
|
my_filter = udf(lambda a: a == 1, BooleanType())
|
|
|
|
sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
|
|
|
|
self.assertEqual(sel.collect(), [Row(key=1)])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_basic_functions(self):
|
|
|
|
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
|
2016-01-04 21:02:38 -05:00
|
|
|
df = self.sqlCtx.read.json(rdd)
|
2015-02-03 19:01:56 -05:00
|
|
|
df.count()
|
|
|
|
df.collect()
|
2015-02-14 02:03:22 -05:00
|
|
|
df.schema
|
2015-02-03 19:01:56 -05:00
|
|
|
|
|
|
|
# cache and checkpoint
|
|
|
|
self.assertFalse(df.is_cached)
|
|
|
|
df.persist()
|
|
|
|
df.unpersist()
|
|
|
|
df.cache()
|
|
|
|
self.assertTrue(df.is_cached)
|
|
|
|
self.assertEqual(2, df.count())
|
|
|
|
|
|
|
|
df.registerTempTable("temp")
|
|
|
|
df = self.sqlCtx.sql("select foo from temp")
|
|
|
|
df.count()
|
|
|
|
df.collect()
|
|
|
|
|
|
|
|
def test_apply_schema_to_row(self):
|
2016-01-04 21:02:38 -05:00
|
|
|
df = self.sqlCtx.read.json(self.sc.parallelize(["""{"a":2}"""]))
|
2016-03-02 18:26:34 -05:00
|
|
|
df2 = self.sqlCtx.createDataFrame(df.rdd.map(lambda x: x), df.schema)
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(df.collect(), df2.collect())
|
|
|
|
|
|
|
|
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
|
2015-12-30 14:14:47 -05:00
|
|
|
df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
|
|
|
|
self.assertEqual(10, df3.count())
|
|
|
|
|
|
|
|
def test_infer_schema_to_local(self):
|
|
|
|
input = [{"a": 1}, {"b": "coffee"}]
|
|
|
|
rdd = self.sc.parallelize(input)
|
|
|
|
df = self.sqlCtx.createDataFrame(input)
|
|
|
|
df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
|
|
|
|
self.assertEqual(df.schema, df2.schema)
|
|
|
|
|
2016-01-03 20:04:35 -05:00
|
|
|
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
|
2015-02-14 02:03:22 -05:00
|
|
|
df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(10, df3.count())
|
|
|
|
|
2016-01-24 22:40:34 -05:00
|
|
|
def test_create_dataframe_schema_mismatch(self):
|
|
|
|
input = [Row(a=1)]
|
|
|
|
rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
|
|
|
|
schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
|
|
|
|
df = self.sqlCtx.createDataFrame(rdd, schema)
|
2016-03-08 17:00:03 -05:00
|
|
|
self.assertRaises(Exception, lambda: df.show())
|
2016-01-24 22:40:34 -05:00
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_serialize_nested_array_and_map(self):
|
|
|
|
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
|
|
|
|
rdd = self.sc.parallelize(d)
|
2015-02-10 22:40:12 -05:00
|
|
|
df = self.sqlCtx.createDataFrame(rdd)
|
2015-02-03 19:01:56 -05:00
|
|
|
row = df.head()
|
|
|
|
self.assertEqual(1, len(row.l))
|
|
|
|
self.assertEqual(1, row.l[0].a)
|
|
|
|
self.assertEqual("2", row.d["key"].d)
|
|
|
|
|
2016-03-02 18:26:34 -05:00
|
|
|
l = df.rdd.map(lambda x: x.l).first()
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, len(l))
|
|
|
|
self.assertEqual('s', l[0].b)
|
|
|
|
|
2016-03-02 18:26:34 -05:00
|
|
|
d = df.rdd.map(lambda x: x.d).first()
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, len(d))
|
|
|
|
self.assertEqual(1.0, d["key"].c)
|
|
|
|
|
2016-03-02 18:26:34 -05:00
|
|
|
row = df.rdd.map(lambda x: x.d["key"]).first()
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1.0, row.c)
|
|
|
|
self.assertEqual("2", row.d)
|
|
|
|
|
|
|
|
def test_infer_schema(self):
|
2015-02-20 18:35:05 -05:00
|
|
|
d = [Row(l=[], d={}, s=None),
|
2015-02-03 19:01:56 -05:00
|
|
|
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
|
|
|
|
rdd = self.sc.parallelize(d)
|
2015-02-10 22:40:12 -05:00
|
|
|
df = self.sqlCtx.createDataFrame(rdd)
|
2016-03-02 18:26:34 -05:00
|
|
|
self.assertEqual([], df.rdd.map(lambda r: r.l).first())
|
|
|
|
self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
|
2015-02-03 19:01:56 -05:00
|
|
|
df.registerTempTable("test")
|
|
|
|
result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
|
|
|
|
self.assertEqual(1, result.head()[0])
|
|
|
|
|
2015-02-11 15:13:16 -05:00
|
|
|
df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
|
2015-02-14 02:03:22 -05:00
|
|
|
self.assertEqual(df.schema, df2.schema)
|
2016-03-02 18:26:34 -05:00
|
|
|
self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
|
|
|
|
self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
|
2015-02-03 19:01:56 -05:00
|
|
|
df2.registerTempTable("test2")
|
|
|
|
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
|
|
|
|
self.assertEqual(1, result.head()[0])
|
|
|
|
|
2015-02-24 23:51:55 -05:00
|
|
|
def test_infer_nested_schema(self):
|
|
|
|
NestedRow = Row("f1", "f2")
|
|
|
|
nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
|
|
|
|
NestedRow([2, 3], {"row2": 2.0})])
|
2016-01-04 21:02:38 -05:00
|
|
|
df = self.sqlCtx.createDataFrame(nestedRdd1)
|
2015-02-24 23:51:55 -05:00
|
|
|
self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
|
|
|
|
|
|
|
|
nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
|
|
|
|
NestedRow([[2, 3], [3, 4]], [2, 3])])
|
2016-01-04 21:02:38 -05:00
|
|
|
df = self.sqlCtx.createDataFrame(nestedRdd2)
|
2015-02-24 23:51:55 -05:00
|
|
|
self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
|
|
|
|
|
|
|
|
from collections import namedtuple
|
|
|
|
CustomRow = namedtuple('CustomRow', 'field1 field2')
|
|
|
|
rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
|
|
|
|
CustomRow(field1=2, field2="row2"),
|
|
|
|
CustomRow(field1=3, field2="row3")])
|
2016-01-04 21:02:38 -05:00
|
|
|
df = self.sqlCtx.createDataFrame(rdd)
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
|
2015-02-24 23:51:55 -05:00
|
|
|
|
2015-08-26 19:04:44 -04:00
|
|
|
def test_create_dataframe_from_objects(self):
|
|
|
|
data = [MyObject(1, "1"), MyObject(2, "2")]
|
|
|
|
df = self.sqlCtx.createDataFrame(data)
|
|
|
|
self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
|
|
|
|
self.assertEqual(df.first(), Row(key=1, value="1"))
|
|
|
|
|
2015-07-20 15:00:48 -04:00
|
|
|
def test_select_null_literal(self):
|
|
|
|
df = self.sqlCtx.sql("select null as col")
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(Row(col=None), df.first())
|
2015-07-20 15:00:48 -04:00
|
|
|
|
2015-02-24 23:51:55 -05:00
|
|
|
def test_apply_schema(self):
|
|
|
|
from datetime import date, datetime
|
2015-04-16 19:20:57 -04:00
|
|
|
rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
|
2015-02-24 23:51:55 -05:00
|
|
|
date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
|
|
|
|
{"a": 1}, (2,), [1, 2, 3], None)])
|
|
|
|
schema = StructType([
|
|
|
|
StructField("byte1", ByteType(), False),
|
|
|
|
StructField("byte2", ByteType(), False),
|
|
|
|
StructField("short1", ShortType(), False),
|
|
|
|
StructField("short2", ShortType(), False),
|
|
|
|
StructField("int1", IntegerType(), False),
|
|
|
|
StructField("float1", FloatType(), False),
|
|
|
|
StructField("date1", DateType(), False),
|
|
|
|
StructField("time1", TimestampType(), False),
|
|
|
|
StructField("map1", MapType(StringType(), IntegerType(), False), False),
|
|
|
|
StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
|
|
|
|
StructField("list1", ArrayType(ByteType(), False), False),
|
|
|
|
StructField("null1", DoubleType(), True)])
|
2015-04-17 12:29:27 -04:00
|
|
|
df = self.sqlCtx.createDataFrame(rdd, schema)
|
2016-03-02 18:26:34 -05:00
|
|
|
results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
|
|
|
|
x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
|
2015-02-24 23:51:55 -05:00
|
|
|
r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
|
|
|
|
datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
|
|
|
|
self.assertEqual(r, results.first())
|
|
|
|
|
|
|
|
df.registerTempTable("table2")
|
|
|
|
r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
|
|
|
|
"short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
|
|
|
|
"float1 + 1.5 as float1 FROM table2").first()
|
|
|
|
|
|
|
|
self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
|
|
|
|
|
|
|
|
from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
|
|
|
|
rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
|
|
|
|
{"a": 1}, (2,), [1, 2, 3])])
|
|
|
|
abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
|
|
|
|
schema = _parse_schema_abstract(abstract)
|
|
|
|
typedSchema = _infer_schema_type(rdd.first(), schema)
|
2015-04-21 20:49:55 -04:00
|
|
|
df = self.sqlCtx.createDataFrame(rdd, typedSchema)
|
2015-02-24 23:51:55 -05:00
|
|
|
r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
|
|
|
|
self.assertEqual(r, tuple(df.first()))
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_struct_in_map(self):
|
|
|
|
d = [Row(m={Row(i=1): Row(s="")})]
|
2015-02-14 02:03:22 -05:00
|
|
|
df = self.sc.parallelize(d).toDF()
|
2015-04-16 19:20:57 -04:00
|
|
|
k, v = list(df.head().m.items())[0]
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertEqual(1, k.i)
|
|
|
|
self.assertEqual("", v.s)
|
|
|
|
|
|
|
|
def test_convert_row_to_dict(self):
|
|
|
|
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
|
|
|
|
self.assertEqual(1, row.asDict()['l'][0].a)
|
2015-02-14 02:03:22 -05:00
|
|
|
df = self.sc.parallelize([row]).toDF()
|
2015-02-03 19:01:56 -05:00
|
|
|
df.registerTempTable("test")
|
|
|
|
row = self.sqlCtx.sql("select l, d from test").head()
|
|
|
|
self.assertEqual(1, row.asDict()["l"][0].a)
|
|
|
|
self.assertEqual(1.0, row.asDict()['d']['key'].c)
|
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
def test_udt(self):
|
|
|
|
from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
|
|
|
|
from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
|
|
|
|
|
|
|
|
def check_datatype(datatype):
|
|
|
|
pickled = pickle.loads(pickle.dumps(datatype))
|
|
|
|
assert datatype == pickled
|
|
|
|
scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
|
|
|
|
python_datatype = _parse_datatype_json_string(scala_datatype.json())
|
|
|
|
assert datatype == python_datatype
|
|
|
|
|
|
|
|
check_datatype(ExamplePointUDT())
|
|
|
|
structtype_with_udt = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", ExamplePointUDT(), False)])
|
|
|
|
check_datatype(structtype_with_udt)
|
|
|
|
p = ExamplePoint(1.0, 2.0)
|
|
|
|
self.assertEqual(_infer_type(p), ExamplePointUDT())
|
|
|
|
_verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
|
|
|
|
self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
|
|
|
|
|
|
|
|
check_datatype(PythonOnlyUDT())
|
|
|
|
structtype_with_udt = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", PythonOnlyUDT(), False)])
|
|
|
|
check_datatype(structtype_with_udt)
|
|
|
|
p = PythonOnlyPoint(1.0, 2.0)
|
|
|
|
self.assertEqual(_infer_type(p), PythonOnlyUDT())
|
|
|
|
_verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
|
|
|
|
self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_infer_schema_with_udt(self):
|
2015-02-09 23:49:22 -05:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-02-03 19:01:56 -05:00
|
|
|
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
|
2015-07-30 01:30:49 -04:00
|
|
|
df = self.sqlCtx.createDataFrame([row])
|
2015-02-14 02:03:22 -05:00
|
|
|
schema = df.schema
|
2015-02-03 19:01:56 -05:00
|
|
|
field = [f for f in schema.fields if f.name == "point"][0]
|
|
|
|
self.assertEqual(type(field.dataType), ExamplePointUDT)
|
|
|
|
df.registerTempTable("labeled_point")
|
|
|
|
point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
|
|
|
|
self.assertEqual(point, ExamplePoint(1.0, 2.0))
|
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
|
|
|
|
df = self.sqlCtx.createDataFrame([row])
|
|
|
|
schema = df.schema
|
|
|
|
field = [f for f in schema.fields if f.name == "point"][0]
|
|
|
|
self.assertEqual(type(field.dataType), PythonOnlyUDT)
|
|
|
|
df.registerTempTable("labeled_point")
|
|
|
|
point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
|
|
|
|
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_apply_schema_with_udt(self):
|
2015-02-09 23:49:22 -05:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-02-03 19:01:56 -05:00
|
|
|
row = (1.0, ExamplePoint(1.0, 2.0))
|
|
|
|
schema = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", ExamplePointUDT(), False)])
|
2015-07-30 01:30:49 -04:00
|
|
|
df = self.sqlCtx.createDataFrame([row], schema)
|
2015-02-03 19:01:56 -05:00
|
|
|
point = df.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, ExamplePoint(1.0, 2.0))
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = (1.0, PythonOnlyPoint(1.0, 2.0))
|
|
|
|
schema = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", PythonOnlyUDT(), False)])
|
|
|
|
df = self.sqlCtx.createDataFrame([row], schema)
|
|
|
|
point = df.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
|
2015-07-30 01:30:49 -04:00
|
|
|
|
2015-07-09 17:43:38 -04:00
|
|
|
def test_udf_with_udt(self):
|
2015-07-20 15:14:47 -04:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-07-09 17:43:38 -04:00
|
|
|
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
|
2015-07-30 01:30:49 -04:00
|
|
|
df = self.sqlCtx.createDataFrame([row])
|
2016-03-02 18:26:34 -05:00
|
|
|
self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
|
2015-07-09 17:43:38 -04:00
|
|
|
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
|
|
|
|
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
|
2015-07-20 15:14:47 -04:00
|
|
|
udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
|
|
|
|
self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
|
2015-07-09 17:43:38 -04:00
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
|
|
|
|
df = self.sqlCtx.createDataFrame([row])
|
2016-03-02 18:26:34 -05:00
|
|
|
self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
|
2015-07-30 01:30:49 -04:00
|
|
|
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
|
|
|
|
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
|
|
|
|
udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
|
|
|
|
self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_parquet_with_udt(self):
|
2015-07-30 01:30:49 -04:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
2015-02-03 19:01:56 -05:00
|
|
|
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
|
2015-07-30 01:30:49 -04:00
|
|
|
df0 = self.sqlCtx.createDataFrame([row])
|
2015-02-03 19:01:56 -05:00
|
|
|
output_dir = os.path.join(self.tempdir.name, "labeled_point")
|
2015-07-30 01:30:49 -04:00
|
|
|
df0.write.parquet(output_dir)
|
2016-01-04 21:02:38 -05:00
|
|
|
df1 = self.sqlCtx.read.parquet(output_dir)
|
2015-02-03 19:01:56 -05:00
|
|
|
point = df1.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, ExamplePoint(1.0, 2.0))
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2015-07-30 01:30:49 -04:00
|
|
|
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
|
|
|
|
df0 = self.sqlCtx.createDataFrame([row])
|
|
|
|
df0.write.parquet(output_dir, mode='overwrite')
|
2016-01-04 21:02:38 -05:00
|
|
|
df1 = self.sqlCtx.read.parquet(output_dir)
|
2015-07-30 01:30:49 -04:00
|
|
|
point = df1.head().point
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
|
2015-07-30 01:30:49 -04:00
|
|
|
|
2016-03-25 01:34:55 -04:00
|
|
|
def test_union_with_udt(self):
|
2016-02-21 19:58:17 -05:00
|
|
|
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
|
|
|
|
row1 = (1.0, ExamplePoint(1.0, 2.0))
|
|
|
|
row2 = (2.0, ExamplePoint(3.0, 4.0))
|
|
|
|
schema = StructType([StructField("label", DoubleType(), False),
|
|
|
|
StructField("point", ExamplePointUDT(), False)])
|
|
|
|
df1 = self.sqlCtx.createDataFrame([row1], schema)
|
|
|
|
df2 = self.sqlCtx.createDataFrame([row2], schema)
|
|
|
|
|
2016-03-25 01:34:55 -04:00
|
|
|
result = df1.union(df2).orderBy("label").collect()
|
2016-02-21 19:58:17 -05:00
|
|
|
self.assertEqual(
|
|
|
|
result,
|
|
|
|
[
|
|
|
|
Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
|
|
|
|
Row(label=2.0, point=ExamplePoint(3.0, 4.0))
|
|
|
|
]
|
|
|
|
)
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_column_operators(self):
|
|
|
|
ci = self.df.key
|
|
|
|
cs = self.df.value
|
|
|
|
c = ci == cs
|
|
|
|
self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
|
2015-09-11 18:19:04 -04:00
|
|
|
rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in rcc))
|
2015-06-23 18:51:16 -04:00
|
|
|
cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
|
2015-02-03 19:01:56 -05:00
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in cb))
|
|
|
|
cbool = (ci & ci), (ci | ci), (~ci)
|
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in cbool))
|
|
|
|
css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
|
|
|
|
self.assertTrue(all(isinstance(c, Column) for c in css))
|
|
|
|
self.assertTrue(isinstance(ci.cast(LongType()), Column))
|
|
|
|
|
|
|
|
def test_column_select(self):
|
|
|
|
df = self.df
|
|
|
|
self.assertEqual(self.testData, df.select("*").collect())
|
|
|
|
self.assertEqual(self.testData, df.select(df.key, df.value).collect())
|
|
|
|
self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
|
|
|
|
|
2015-05-02 02:43:24 -04:00
|
|
|
def test_freqItems(self):
|
|
|
|
vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
|
|
|
|
df = self.sc.parallelize(vals).toDF()
|
|
|
|
items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
|
|
|
|
self.assertTrue(1 in items[0])
|
|
|
|
self.assertTrue(-2.0 in items[1])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_aggregator(self):
|
|
|
|
df = self.df
|
|
|
|
g = df.groupBy()
|
|
|
|
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
|
|
|
|
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
|
|
|
|
|
2015-02-14 02:03:22 -05:00
|
|
|
from pyspark.sql import functions
|
|
|
|
self.assertEqual((0, u'99'),
|
|
|
|
tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
|
|
|
|
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
|
|
|
|
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2016-01-31 16:56:13 -05:00
|
|
|
def test_first_last_ignorenulls(self):
|
|
|
|
from pyspark.sql import functions
|
|
|
|
df = self.sqlCtx.range(0, 100)
|
|
|
|
df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
|
|
|
|
df3 = df2.select(functions.first(df2.id, False).alias('a'),
|
|
|
|
functions.first(df2.id, True).alias('b'),
|
|
|
|
functions.last(df2.id, False).alias('c'),
|
|
|
|
functions.last(df2.id, True).alias('d'))
|
|
|
|
self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
|
|
|
|
|
2016-02-25 02:15:36 -05:00
|
|
|
def test_approxQuantile(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i) for i in range(10)]).toDF()
|
|
|
|
aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
|
|
|
|
self.assertTrue(isinstance(aq, list))
|
|
|
|
self.assertEqual(len(aq), 3)
|
|
|
|
self.assertTrue(all(isinstance(q, float) for q in aq))
|
|
|
|
|
2015-05-04 00:44:39 -04:00
|
|
|
def test_corr(self):
|
|
|
|
import math
|
|
|
|
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
|
|
|
|
corr = df.stat.corr("a", "b")
|
|
|
|
self.assertTrue(abs(corr - 0.95734012) < 1e-6)
|
|
|
|
|
2015-05-01 16:29:17 -04:00
|
|
|
def test_cov(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
|
|
|
|
cov = df.stat.cov("a", "b")
|
|
|
|
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
|
|
|
|
|
2015-05-04 20:02:49 -04:00
|
|
|
def test_crosstab(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
|
|
|
|
ct = df.stat.crosstab("a", "b").collect()
|
|
|
|
ct = sorted(ct, key=lambda x: x[0])
|
|
|
|
for i, row in enumerate(ct):
|
|
|
|
self.assertEqual(row[0], str(i))
|
|
|
|
self.assertTrue(row[1], 1)
|
|
|
|
self.assertTrue(row[2], 1)
|
|
|
|
|
2015-04-29 03:09:24 -04:00
|
|
|
def test_math_functions(self):
|
|
|
|
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
|
2015-05-06 01:56:01 -04:00
|
|
|
from pyspark.sql import functions
|
2015-04-29 03:09:24 -04:00
|
|
|
import math
|
|
|
|
|
|
|
|
def get_values(l):
|
|
|
|
return [j[0] for j in l]
|
|
|
|
|
|
|
|
def assert_close(a, b):
|
|
|
|
c = get_values(b)
|
|
|
|
diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
|
|
|
|
return sum(diff) == len(a)
|
|
|
|
assert_close([math.cos(i) for i in range(10)],
|
|
|
|
df.select(functions.cos(df.a)).collect())
|
|
|
|
assert_close([math.cos(i) for i in range(10)],
|
|
|
|
df.select(functions.cos("a")).collect())
|
|
|
|
assert_close([math.sin(i) for i in range(10)],
|
|
|
|
df.select(functions.sin(df.a)).collect())
|
|
|
|
assert_close([math.sin(i) for i in range(10)],
|
|
|
|
df.select(functions.sin(df['a'])).collect())
|
|
|
|
assert_close([math.pow(i, 2 * i) for i in range(10)],
|
|
|
|
df.select(functions.pow(df.a, df.b)).collect())
|
|
|
|
assert_close([math.pow(i, 2) for i in range(10)],
|
|
|
|
df.select(functions.pow(df.a, 2)).collect())
|
|
|
|
assert_close([math.pow(i, 2) for i in range(10)],
|
|
|
|
df.select(functions.pow(df.a, 2.0)).collect())
|
|
|
|
assert_close([math.hypot(i, 2 * i) for i in range(10)],
|
|
|
|
df.select(functions.hypot(df.a, df.b)).collect())
|
|
|
|
|
2015-05-01 00:56:03 -04:00
|
|
|
def test_rand_functions(self):
|
|
|
|
df = self.df
|
|
|
|
from pyspark.sql import functions
|
|
|
|
rnd = df.select('key', functions.rand()).collect()
|
|
|
|
for row in rnd:
|
|
|
|
assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
|
|
|
|
rndn = df.select('key', functions.randn(5)).collect()
|
|
|
|
for row in rndn:
|
|
|
|
assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
|
|
|
|
|
2015-08-06 20:03:14 -04:00
|
|
|
# If the specified seed is 0, we should use it.
|
|
|
|
# https://issues.apache.org/jira/browse/SPARK-9691
|
|
|
|
rnd1 = df.select('key', functions.rand(0)).collect()
|
|
|
|
rnd2 = df.select('key', functions.rand(0)).collect()
|
|
|
|
self.assertEqual(sorted(rnd1), sorted(rnd2))
|
|
|
|
|
|
|
|
rndn1 = df.select('key', functions.randn(0)).collect()
|
|
|
|
rndn2 = df.select('key', functions.randn(0)).collect()
|
|
|
|
self.assertEqual(sorted(rndn1), sorted(rndn2))
|
|
|
|
|
2015-05-05 16:23:53 -04:00
|
|
|
def test_between_function(self):
|
|
|
|
df = self.sc.parallelize([
|
|
|
|
Row(a=1, b=2, c=3),
|
|
|
|
Row(a=2, b=1, c=3),
|
|
|
|
Row(a=4, b=1, c=4)]).toDF()
|
|
|
|
self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
|
|
|
|
df.filter(df.a.between(df.b, df.c)).collect())
|
|
|
|
|
2015-06-29 17:15:15 -04:00
|
|
|
def test_struct_type(self):
|
|
|
|
from pyspark.sql.types import StructType, StringType, StructField
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True),
|
|
|
|
StructField("f2", StringType(), True, None)])
|
|
|
|
self.assertEqual(struct1, struct2)
|
|
|
|
|
|
|
|
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True)])
|
|
|
|
self.assertNotEqual(struct1, struct2)
|
|
|
|
|
|
|
|
struct1 = (StructType().add(StructField("f1", StringType(), True))
|
|
|
|
.add(StructField("f2", StringType(), True, None)))
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True),
|
|
|
|
StructField("f2", StringType(), True, None)])
|
|
|
|
self.assertEqual(struct1, struct2)
|
|
|
|
|
|
|
|
struct1 = (StructType().add(StructField("f1", StringType(), True))
|
|
|
|
.add(StructField("f2", StringType(), True, None)))
|
|
|
|
struct2 = StructType([StructField("f1", StringType(), True)])
|
|
|
|
self.assertNotEqual(struct1, struct2)
|
|
|
|
|
|
|
|
# Catch exception raised during improper construction
|
|
|
|
try:
|
|
|
|
struct1 = StructType().add("name")
|
|
|
|
self.assertEqual(1, 0)
|
|
|
|
except ValueError:
|
|
|
|
self.assertEqual(1, 1)
|
|
|
|
|
2016-01-27 12:55:10 -05:00
|
|
|
def test_metadata_null(self):
|
|
|
|
from pyspark.sql.types import StructType, StringType, StructField
|
|
|
|
schema = StructType([StructField("f1", StringType(), True, None),
|
|
|
|
StructField("f2", StringType(), True, {'a': None})])
|
|
|
|
rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
|
|
|
|
self.sqlCtx.createDataFrame(rdd, schema)
|
|
|
|
|
2015-02-10 20:29:52 -05:00
|
|
|
def test_save_and_load(self):
|
|
|
|
df = self.df
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.json(tmpPath)
|
|
|
|
actual = self.sqlCtx.read.json(tmpPath)
|
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
|
|
|
schema = StructType([StructField("value", StringType(), True)])
|
2015-05-19 17:23:28 -04:00
|
|
|
actual = self.sqlCtx.read.json(tmpPath, schema)
|
|
|
|
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.json(tmpPath, "overwrite")
|
|
|
|
actual = self.sqlCtx.read.json(tmpPath)
|
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.save(format="json", mode="overwrite", path=tmpPath,
|
|
|
|
noUse="this options will not be used in save.")
|
|
|
|
actual = self.sqlCtx.read.load(format="json", path=tmpPath,
|
|
|
|
noUse="this options will not be used in load.")
|
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
|
|
|
|
defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
|
|
|
|
"org.apache.spark.sql.parquet")
|
|
|
|
self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
|
2016-01-04 21:02:38 -05:00
|
|
|
actual = self.sqlCtx.read.load(path=tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
|
2015-06-22 16:51:23 -04:00
|
|
|
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
|
|
|
def test_save_and_load_builder(self):
|
|
|
|
df = self.df
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
df.write.json(tmpPath)
|
|
|
|
actual = self.sqlCtx.read.json(tmpPath)
|
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
|
|
|
|
|
|
|
schema = StructType([StructField("value", StringType(), True)])
|
|
|
|
actual = self.sqlCtx.read.json(tmpPath, schema)
|
|
|
|
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
|
|
|
|
|
|
|
|
df.write.mode("overwrite").json(tmpPath)
|
|
|
|
actual = self.sqlCtx.read.json(tmpPath)
|
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
|
|
|
|
|
|
|
df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
|
2015-06-29 03:13:39 -04:00
|
|
|
.option("noUse", "this option will not be used in save.")\
|
2015-06-22 16:51:23 -04:00
|
|
|
.format("json").save(path=tmpPath)
|
|
|
|
actual =\
|
|
|
|
self.sqlCtx.read.format("json")\
|
|
|
|
.load(path=tmpPath, noUse="this options will not be used in load.")
|
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
|
|
|
|
|
|
|
defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
|
|
|
|
"org.apache.spark.sql.parquet")
|
|
|
|
self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
|
2016-01-04 21:02:38 -05:00
|
|
|
actual = self.sqlCtx.read.load(path=tmpPath)
|
2015-06-22 16:51:23 -04:00
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
|
|
|
self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
|
2015-02-10 20:29:52 -05:00
|
|
|
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
def test_help_command(self):
|
|
|
|
# Regression test for SPARK-5464
|
|
|
|
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
|
2016-01-04 21:02:38 -05:00
|
|
|
df = self.sqlCtx.read.json(rdd)
|
2015-02-03 19:01:56 -05:00
|
|
|
# render_doc() reproduces the help() exception without printing output
|
|
|
|
pydoc.render_doc(df)
|
|
|
|
pydoc.render_doc(df.foo)
|
|
|
|
pydoc.render_doc(df.take(1))
|
|
|
|
|
2015-04-16 20:33:57 -04:00
|
|
|
def test_access_column(self):
|
|
|
|
df = self.df
|
|
|
|
self.assertTrue(isinstance(df.key, Column))
|
|
|
|
self.assertTrue(isinstance(df['key'], Column))
|
|
|
|
self.assertTrue(isinstance(df[0], Column))
|
|
|
|
self.assertRaises(IndexError, lambda: df[2])
|
2015-08-14 17:09:46 -04:00
|
|
|
self.assertRaises(AnalysisException, lambda: df["bad_key"])
|
2015-04-16 20:33:57 -04:00
|
|
|
self.assertRaises(TypeError, lambda: df[{}])
|
|
|
|
|
2015-07-01 19:43:18 -04:00
|
|
|
def test_column_name_with_non_ascii(self):
|
|
|
|
df = self.sqlCtx.createDataFrame([(1,)], ["数量"])
|
|
|
|
self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema)
|
|
|
|
self.assertEqual("DataFrame[数量: bigint]", str(df))
|
|
|
|
self.assertEqual([("数量", 'bigint')], df.dtypes)
|
|
|
|
self.assertEqual(1, df.select("数量").first()[0])
|
|
|
|
self.assertEqual(1, df.select(df["数量"]).first()[0])
|
|
|
|
|
2015-04-16 20:33:57 -04:00
|
|
|
def test_access_nested_types(self):
|
|
|
|
df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
|
|
|
|
self.assertEqual(1, df.select(df.l[0]).first()[0])
|
|
|
|
self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
|
|
|
|
self.assertEqual(1, df.select(df.r.a).first()[0])
|
|
|
|
self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
|
|
|
|
self.assertEqual("v", df.select(df.d["k"]).first()[0])
|
|
|
|
self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
|
|
|
|
|
2015-05-08 14:49:38 -04:00
|
|
|
def test_field_accessor(self):
|
|
|
|
df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
|
|
|
|
self.assertEqual(1, df.select(df.l[0]).first()[0])
|
|
|
|
self.assertEqual(1, df.select(df.r["a"]).first()[0])
|
2015-08-14 17:09:46 -04:00
|
|
|
self.assertEqual(1, df.select(df["r.a"]).first()[0])
|
2015-05-08 14:49:38 -04:00
|
|
|
self.assertEqual("b", df.select(df.r["b"]).first()[0])
|
2015-08-14 17:09:46 -04:00
|
|
|
self.assertEqual("b", df.select(df["r.b"]).first()[0])
|
2015-05-08 14:49:38 -04:00
|
|
|
self.assertEqual("v", df.select(df.d["k"]).first()[0])
|
|
|
|
|
2015-02-18 17:17:04 -05:00
|
|
|
def test_infer_long_type(self):
|
|
|
|
longrow = [Row(f1='a', f2=100000000000000)]
|
|
|
|
df = self.sc.parallelize(longrow).toDF()
|
|
|
|
self.assertEqual(df.schema.fields[1].dataType, LongType())
|
|
|
|
|
|
|
|
# this saving as Parquet caused issues as well.
|
|
|
|
output_dir = os.path.join(self.tempdir.name, "infer_long_type")
|
2016-01-04 21:02:38 -05:00
|
|
|
df.write.parquet(output_dir)
|
|
|
|
df1 = self.sqlCtx.read.parquet(output_dir)
|
2015-09-18 12:53:52 -04:00
|
|
|
self.assertEqual('a', df1.first().f1)
|
|
|
|
self.assertEqual(100000000000000, df1.first().f2)
|
2015-02-18 17:17:04 -05:00
|
|
|
|
|
|
|
self.assertEqual(_infer_type(1), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**10), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**20), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**31 - 1), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**31), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**61), LongType())
|
|
|
|
self.assertEqual(_infer_type(2**71), LongType())
|
|
|
|
|
2015-04-21 03:08:18 -04:00
|
|
|
def test_filter_with_datetime(self):
|
|
|
|
time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
|
|
|
|
date = time.date()
|
|
|
|
row = Row(date=date, time=time)
|
|
|
|
df = self.sqlCtx.createDataFrame([row])
|
|
|
|
self.assertEqual(1, df.filter(df.date == date).count())
|
|
|
|
self.assertEqual(1, df.filter(df.time == time).count())
|
|
|
|
self.assertEqual(0, df.filter(df.date > date).count())
|
|
|
|
self.assertEqual(0, df.filter(df.time > time).count())
|
|
|
|
|
[SPARK-10162] [SQL] Fix the timezone omitting for PySpark Dataframe filter function
This PR addresses [SPARK-10162](https://issues.apache.org/jira/browse/SPARK-10162)
The issue is with DataFrame filter() function, if datetime.datetime is passed to it:
* Timezone information of this datetime is ignored
* This datetime is assumed to be in local timezone, which depends on the OS timezone setting
Fix includes both code change and regression test. Problem reproduction code on master:
```python
import pytz
from datetime import datetime
from pyspark.sql import *
from pyspark.sql.types import *
sqc = SQLContext(sc)
df = sqc.createDataFrame([], StructType([StructField("dt", TimestampType())]))
m1 = pytz.timezone('UTC')
m2 = pytz.timezone('Etc/GMT+3')
df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
```
It gives the same timestamp ignoring time zone:
```
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
Filter (dt#0 > 946713600000000)
Scan PhysicalRDD[dt#0]
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
Filter (dt#0 > 946713600000000)
Scan PhysicalRDD[dt#0]
```
After the fix:
```
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
Filter (dt#0 > 946684800000000)
Scan PhysicalRDD[dt#0]
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
Filter (dt#0 > 946695600000000)
Scan PhysicalRDD[dt#0]
```
PR [8536](https://github.com/apache/spark/pull/8536) was occasionally closed by me dropping the repo
Author: 0x0FFF <programmerag@gmail.com>
Closes #8555 from 0x0FFF/SPARK-10162.
2015-09-01 17:34:59 -04:00
|
|
|
def test_filter_with_datetime_timezone(self):
|
|
|
|
dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
|
|
|
|
dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
|
|
|
|
row = Row(date=dt1)
|
|
|
|
df = self.sqlCtx.createDataFrame([row])
|
|
|
|
self.assertEqual(0, df.filter(df.date == dt2).count())
|
|
|
|
self.assertEqual(1, df.filter(df.date > dt2).count())
|
|
|
|
self.assertEqual(0, df.filter(df.date < dt2).count())
|
|
|
|
|
2015-06-11 04:00:41 -04:00
|
|
|
def test_time_with_timezone(self):
|
|
|
|
day = datetime.date.today()
|
|
|
|
now = datetime.datetime.now()
|
2015-07-10 16:05:23 -04:00
|
|
|
ts = time.mktime(now.timetuple())
|
2015-06-11 04:00:41 -04:00
|
|
|
# class in __main__ is not serializable
|
[SPARK-10162] [SQL] Fix the timezone omitting for PySpark Dataframe filter function
This PR addresses [SPARK-10162](https://issues.apache.org/jira/browse/SPARK-10162)
The issue is with DataFrame filter() function, if datetime.datetime is passed to it:
* Timezone information of this datetime is ignored
* This datetime is assumed to be in local timezone, which depends on the OS timezone setting
Fix includes both code change and regression test. Problem reproduction code on master:
```python
import pytz
from datetime import datetime
from pyspark.sql import *
from pyspark.sql.types import *
sqc = SQLContext(sc)
df = sqc.createDataFrame([], StructType([StructField("dt", TimestampType())]))
m1 = pytz.timezone('UTC')
m2 = pytz.timezone('Etc/GMT+3')
df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
```
It gives the same timestamp ignoring time zone:
```
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
Filter (dt#0 > 946713600000000)
Scan PhysicalRDD[dt#0]
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
Filter (dt#0 > 946713600000000)
Scan PhysicalRDD[dt#0]
```
After the fix:
```
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain()
Filter (dt#0 > 946684800000000)
Scan PhysicalRDD[dt#0]
>>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain()
Filter (dt#0 > 946695600000000)
Scan PhysicalRDD[dt#0]
```
PR [8536](https://github.com/apache/spark/pull/8536) was occasionally closed by me dropping the repo
Author: 0x0FFF <programmerag@gmail.com>
Closes #8555 from 0x0FFF/SPARK-10162.
2015-09-01 17:34:59 -04:00
|
|
|
from pyspark.sql.tests import UTCOffsetTimezone
|
|
|
|
utc = UTCOffsetTimezone()
|
2015-07-10 16:05:23 -04:00
|
|
|
utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds
|
2015-07-10 20:44:21 -04:00
|
|
|
# add microseconds to utcnow (keeping year,month,day,hour,minute,second)
|
2015-07-10 16:05:23 -04:00
|
|
|
utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
|
2015-06-11 04:00:41 -04:00
|
|
|
df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
|
|
|
|
day1, now1, utcnow1 = df.first()
|
2015-07-09 17:43:38 -04:00
|
|
|
self.assertEqual(day1, day)
|
|
|
|
self.assertEqual(now, now1)
|
|
|
|
self.assertEqual(now, utcnow1)
|
2015-06-11 04:00:41 -04:00
|
|
|
|
2015-07-08 21:22:53 -04:00
|
|
|
def test_decimal(self):
|
|
|
|
from decimal import Decimal
|
|
|
|
schema = StructType([StructField("decimal", DecimalType(10, 5))])
|
|
|
|
df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
|
|
|
|
row = df.select(df.decimal + 1).first()
|
|
|
|
self.assertEqual(row[0], Decimal("4.14159"))
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
df.write.parquet(tmpPath)
|
|
|
|
df2 = self.sqlCtx.read.parquet(tmpPath)
|
|
|
|
row = df2.first()
|
|
|
|
self.assertEqual(row[0], Decimal("3.14159"))
|
|
|
|
|
2015-03-30 23:47:10 -04:00
|
|
|
def test_dropna(self):
|
|
|
|
schema = StructType([
|
|
|
|
StructField("name", StringType(), True),
|
|
|
|
StructField("age", IntegerType(), True),
|
|
|
|
StructField("height", DoubleType(), True)])
|
|
|
|
|
|
|
|
# shouldn't drop a non-null row
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', 50, 80.1)], schema).dropna().count(),
|
|
|
|
1)
|
|
|
|
|
|
|
|
# dropping rows with a single null value
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', None, 80.1)], schema).dropna().count(),
|
|
|
|
0)
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# if how = 'all', only drop rows if all values are null
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
|
|
|
|
1)
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(None, None, None)], schema).dropna(how='all').count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# how and subset
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
|
|
|
|
1)
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# threshold
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
|
|
|
|
1)
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', None, None)], schema).dropna(thresh=2).count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# threshold and subset
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
|
|
|
|
1)
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
|
|
|
|
0)
|
|
|
|
|
|
|
|
# thresh should take precedence over how
|
|
|
|
self.assertEqual(self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', 50, None)], schema).dropna(
|
|
|
|
how='any', thresh=2, subset=['name', 'age']).count(),
|
|
|
|
1)
|
|
|
|
|
|
|
|
def test_fillna(self):
|
|
|
|
schema = StructType([
|
|
|
|
StructField("name", StringType(), True),
|
|
|
|
StructField("age", IntegerType(), True),
|
|
|
|
StructField("height", DoubleType(), True)])
|
|
|
|
|
|
|
|
# fillna shouldn't change non-null values
|
|
|
|
row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
|
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
|
|
|
|
# fillna with int
|
|
|
|
row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
|
|
|
|
self.assertEqual(row.age, 50)
|
|
|
|
self.assertEqual(row.height, 50.0)
|
|
|
|
|
|
|
|
# fillna with double
|
|
|
|
row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
|
|
|
|
self.assertEqual(row.age, 50)
|
|
|
|
self.assertEqual(row.height, 50.1)
|
|
|
|
|
|
|
|
# fillna with string
|
|
|
|
row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first()
|
|
|
|
self.assertEqual(row.name, u"hello")
|
|
|
|
self.assertEqual(row.age, None)
|
|
|
|
|
|
|
|
# fillna with subset specified for numeric cols
|
|
|
|
row = self.sqlCtx.createDataFrame(
|
|
|
|
[(None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
|
|
|
|
self.assertEqual(row.name, None)
|
|
|
|
self.assertEqual(row.age, 50)
|
|
|
|
self.assertEqual(row.height, None)
|
|
|
|
|
|
|
|
# fillna with subset specified for numeric cols
|
|
|
|
row = self.sqlCtx.createDataFrame(
|
|
|
|
[(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
|
|
|
|
self.assertEqual(row.name, "haha")
|
|
|
|
self.assertEqual(row.age, None)
|
|
|
|
self.assertEqual(row.height, None)
|
|
|
|
|
2015-05-07 04:00:29 -04:00
|
|
|
def test_bitwise_operations(self):
|
|
|
|
from pyspark.sql import functions
|
|
|
|
row = Row(a=170, b=75)
|
|
|
|
df = self.sqlCtx.createDataFrame([row])
|
|
|
|
result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(170 & 75, result['(a & b)'])
|
|
|
|
result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(170 | 75, result['(a | b)'])
|
|
|
|
result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(170 ^ 75, result['(a ^ b)'])
|
|
|
|
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
|
|
|
|
self.assertEqual(~75, result['~b'])
|
|
|
|
|
2015-07-25 03:34:59 -04:00
|
|
|
def test_expr(self):
|
|
|
|
from pyspark.sql import functions
|
|
|
|
row = Row(a="length string", b=75)
|
|
|
|
df = self.sqlCtx.createDataFrame([row])
|
|
|
|
result = df.select(functions.expr("length(a)")).collect()[0].asDict()
|
2015-11-10 14:06:29 -05:00
|
|
|
self.assertEqual(13, result["length(a)"])
|
2015-07-25 03:34:59 -04:00
|
|
|
|
2015-05-12 13:23:41 -04:00
|
|
|
def test_replace(self):
|
|
|
|
schema = StructType([
|
|
|
|
StructField("name", StringType(), True),
|
|
|
|
StructField("age", IntegerType(), True),
|
|
|
|
StructField("height", DoubleType(), True)])
|
|
|
|
|
|
|
|
# replace with int
|
|
|
|
row = self.sqlCtx.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
|
|
|
|
self.assertEqual(row.age, 20)
|
|
|
|
self.assertEqual(row.height, 20.0)
|
|
|
|
|
|
|
|
# replace with double
|
|
|
|
row = self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
|
|
|
|
self.assertEqual(row.age, 82)
|
|
|
|
self.assertEqual(row.height, 82.1)
|
|
|
|
|
|
|
|
# replace with string
|
|
|
|
row = self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
|
|
|
|
self.assertEqual(row.name, u"Ann")
|
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
|
|
|
|
# replace with subset specified by a string of a column name w/ actual change
|
|
|
|
row = self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
|
|
|
|
self.assertEqual(row.age, 20)
|
|
|
|
|
|
|
|
# replace with subset specified by a string of a column name w/o actual change
|
|
|
|
row = self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
|
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
|
|
|
|
# replace with subset specified with one column replaced, another column not in subset
|
|
|
|
# stays unchanged.
|
|
|
|
row = self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
|
|
|
|
self.assertEqual(row.name, u'Alice')
|
|
|
|
self.assertEqual(row.age, 20)
|
|
|
|
self.assertEqual(row.height, 10.0)
|
|
|
|
|
|
|
|
# replace with subset specified but no column will be replaced
|
|
|
|
row = self.sqlCtx.createDataFrame(
|
|
|
|
[(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
|
|
|
|
self.assertEqual(row.name, u'Alice')
|
|
|
|
self.assertEqual(row.age, 10)
|
|
|
|
self.assertEqual(row.height, None)
|
|
|
|
|
2015-06-30 19:17:46 -04:00
|
|
|
def test_capture_analysis_exception(self):
|
|
|
|
self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
|
|
|
|
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
|
2016-03-28 15:31:12 -04:00
|
|
|
|
|
|
|
def test_capture_parse_exception(self):
|
|
|
|
self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc"))
|
2015-06-30 19:17:46 -04:00
|
|
|
|
2015-07-19 03:32:56 -04:00
|
|
|
def test_capture_illegalargument_exception(self):
|
|
|
|
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
|
|
|
|
lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1"))
|
|
|
|
df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"])
|
|
|
|
self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
|
|
|
|
lambda: df.select(sha2(df.a, 1024)).collect())
|
2015-10-29 00:45:00 -04:00
|
|
|
try:
|
|
|
|
df.select(sha2(df.a, 1024)).collect()
|
|
|
|
except IllegalArgumentException as e:
|
|
|
|
self.assertRegexpMatches(e.desc, "1024 is not in the permitted values")
|
|
|
|
self.assertRegexpMatches(e.stackTrace,
|
|
|
|
"org.apache.spark.sql.functions")
|
2015-07-19 03:32:56 -04:00
|
|
|
|
2015-08-19 16:56:40 -04:00
|
|
|
def test_with_column_with_existing_name(self):
|
|
|
|
keys = self.df.withColumn("key", self.df.key).select("key").collect()
|
|
|
|
self.assertEqual([r.key for r in keys], list(range(100)))
|
|
|
|
|
2015-09-02 16:36:36 -04:00
|
|
|
# regression test for SPARK-10417
|
|
|
|
def test_column_iterator(self):
|
|
|
|
|
|
|
|
def foo():
|
|
|
|
for x in self.df.key:
|
|
|
|
break
|
|
|
|
|
|
|
|
self.assertRaises(TypeError, foo)
|
|
|
|
|
2015-09-22 02:36:41 -04:00
|
|
|
# add test for SPARK-10577 (test broadcast join hint)
|
|
|
|
def test_functions_broadcast(self):
|
|
|
|
from pyspark.sql.functions import broadcast
|
|
|
|
|
|
|
|
df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
|
|
|
|
df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
|
|
|
|
|
|
|
|
# equijoin - should be converted into broadcast join
|
|
|
|
plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
|
|
|
|
self.assertEqual(1, plan1.toString().count("BroadcastHashJoin"))
|
|
|
|
|
|
|
|
# no join key -- should not be a broadcast join
|
|
|
|
plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan()
|
|
|
|
self.assertEqual(0, plan2.toString().count("BroadcastHashJoin"))
|
|
|
|
|
|
|
|
# planner should not crash without a join
|
|
|
|
broadcast(df1)._jdf.queryExecution().executedPlan()
|
|
|
|
|
2016-03-08 17:00:03 -05:00
|
|
|
def test_toDF_with_schema_string(self):
|
|
|
|
data = [Row(key=i, value=str(i)) for i in range(100)]
|
|
|
|
rdd = self.sc.parallelize(data, 5)
|
|
|
|
|
|
|
|
df = rdd.toDF("key: int, value: string")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
|
|
|
|
self.assertEqual(df.collect(), data)
|
|
|
|
|
|
|
|
# different but compatible field types can be used.
|
|
|
|
df = rdd.toDF("key: string, value: string")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
|
|
|
|
self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
|
|
|
|
|
|
|
|
# field names can differ.
|
|
|
|
df = rdd.toDF(" a: int, b: string ")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
|
|
|
|
self.assertEqual(df.collect(), data)
|
|
|
|
|
|
|
|
# number of fields must match.
|
|
|
|
self.assertRaisesRegexp(Exception, "Length of object",
|
|
|
|
lambda: rdd.toDF("key: int").collect())
|
|
|
|
|
|
|
|
# field types mismatch will cause exception at runtime.
|
|
|
|
self.assertRaisesRegexp(Exception, "FloatType can not accept",
|
|
|
|
lambda: rdd.toDF("key: float, value: string").collect())
|
|
|
|
|
|
|
|
# flat schema values will be wrapped into row.
|
|
|
|
df = rdd.map(lambda row: row.key).toDF("int")
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<value:int>")
|
|
|
|
self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
|
|
|
|
|
|
|
|
# users can use DataType directly instead of data type string.
|
|
|
|
df = rdd.map(lambda row: row.key).toDF(IntegerType())
|
|
|
|
self.assertEqual(df.schema.simpleString(), "struct<value:int>")
|
|
|
|
self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
|
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
|
2015-02-10 20:29:52 -05:00
|
|
|
class HiveContextSQLTests(ReusedPySparkTestCase):
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def setUpClass(cls):
|
|
|
|
ReusedPySparkTestCase.setUpClass()
|
|
|
|
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
|
2015-02-17 18:44:37 -05:00
|
|
|
try:
|
|
|
|
cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
|
|
|
|
except py4j.protocol.Py4JError:
|
2015-06-03 03:23:34 -04:00
|
|
|
cls.tearDownClass()
|
2015-05-23 11:30:05 -04:00
|
|
|
raise unittest.SkipTest("Hive is not available")
|
2015-04-16 19:20:57 -04:00
|
|
|
except TypeError:
|
2015-06-03 03:23:34 -04:00
|
|
|
cls.tearDownClass()
|
2015-05-23 11:30:05 -04:00
|
|
|
raise unittest.SkipTest("Hive is not available")
|
2015-02-10 20:29:52 -05:00
|
|
|
os.unlink(cls.tempdir.name)
|
|
|
|
_scala_HiveContext =\
|
|
|
|
cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
|
|
|
|
cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
|
|
|
|
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
|
2015-02-14 02:03:22 -05:00
|
|
|
cls.df = cls.sc.parallelize(cls.testData).toDF()
|
2015-02-10 20:29:52 -05:00
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def tearDownClass(cls):
|
|
|
|
ReusedPySparkTestCase.tearDownClass()
|
|
|
|
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
|
|
|
|
|
|
|
|
def test_save_and_load_table(self):
|
|
|
|
df = self.df
|
|
|
|
tmpPath = tempfile.mkdtemp()
|
|
|
|
shutil.rmtree(tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
|
|
|
|
actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json")
|
|
|
|
self.assertEqual(sorted(df.collect()),
|
|
|
|
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
|
|
|
|
self.assertEqual(sorted(df.collect()),
|
|
|
|
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
|
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
self.sqlCtx.sql("DROP TABLE externalJsonTable")
|
|
|
|
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
|
2015-02-10 20:29:52 -05:00
|
|
|
schema = StructType([StructField("value", StringType(), True)])
|
2015-05-19 17:23:28 -04:00
|
|
|
actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json",
|
2015-02-10 20:29:52 -05:00
|
|
|
schema=schema, path=tmpPath,
|
|
|
|
noUse="this options will not be used")
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()),
|
|
|
|
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
|
|
|
|
self.assertEqual(sorted(df.select("value").collect()),
|
|
|
|
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
|
|
|
|
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
self.sqlCtx.sql("DROP TABLE savedJsonTable")
|
|
|
|
self.sqlCtx.sql("DROP TABLE externalJsonTable")
|
|
|
|
|
|
|
|
defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
|
|
|
|
"org.apache.spark.sql.parquet")
|
|
|
|
self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
|
2015-05-19 17:23:28 -04:00
|
|
|
df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
|
2015-02-10 20:29:52 -05:00
|
|
|
actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
|
2015-05-19 17:23:28 -04:00
|
|
|
self.assertEqual(sorted(df.collect()),
|
|
|
|
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
|
|
|
|
self.assertEqual(sorted(df.collect()),
|
|
|
|
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
|
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
2015-02-10 20:29:52 -05:00
|
|
|
self.sqlCtx.sql("DROP TABLE savedJsonTable")
|
|
|
|
self.sqlCtx.sql("DROP TABLE externalJsonTable")
|
|
|
|
self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
|
|
|
|
|
|
|
|
shutil.rmtree(tmpPath)
|
|
|
|
|
2015-05-23 11:30:05 -04:00
|
|
|
def test_window_functions(self):
|
|
|
|
df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
|
|
|
|
w = Window.partitionBy("value").orderBy("key")
|
|
|
|
from pyspark.sql import functions as F
|
|
|
|
sel = df.select(df.value, df.key,
|
|
|
|
F.max("key").over(w.rowsBetween(0, 1)),
|
|
|
|
F.min("key").over(w.rowsBetween(0, 1)),
|
|
|
|
F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
|
2016-01-04 21:02:38 -05:00
|
|
|
F.row_number().over(w),
|
2015-05-23 11:30:05 -04:00
|
|
|
F.rank().over(w),
|
2016-01-04 21:02:38 -05:00
|
|
|
F.dense_rank().over(w),
|
2015-05-23 11:30:05 -04:00
|
|
|
F.ntile(2).over(w))
|
|
|
|
rs = sorted(sel.collect())
|
|
|
|
expected = [
|
|
|
|
("1", 1, 1, 1, 1, 1, 1, 1, 1),
|
|
|
|
("2", 1, 1, 1, 3, 1, 1, 1, 1),
|
|
|
|
("2", 1, 2, 1, 3, 2, 1, 1, 1),
|
|
|
|
("2", 2, 2, 2, 3, 3, 3, 2, 2)
|
|
|
|
]
|
|
|
|
for r, ex in zip(rs, expected):
|
|
|
|
self.assertEqual(tuple(r), ex[:len(r)])
|
2015-05-19 17:23:28 -04:00
|
|
|
|
2015-08-14 16:55:29 -04:00
|
|
|
def test_window_functions_without_partitionBy(self):
|
|
|
|
df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
|
|
|
|
w = Window.orderBy("key", df.value)
|
|
|
|
from pyspark.sql import functions as F
|
|
|
|
sel = df.select(df.value, df.key,
|
|
|
|
F.max("key").over(w.rowsBetween(0, 1)),
|
|
|
|
F.min("key").over(w.rowsBetween(0, 1)),
|
|
|
|
F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
|
2016-01-04 21:02:38 -05:00
|
|
|
F.row_number().over(w),
|
2015-08-14 16:55:29 -04:00
|
|
|
F.rank().over(w),
|
2016-01-04 21:02:38 -05:00
|
|
|
F.dense_rank().over(w),
|
2015-08-14 16:55:29 -04:00
|
|
|
F.ntile(2).over(w))
|
|
|
|
rs = sorted(sel.collect())
|
|
|
|
expected = [
|
|
|
|
("1", 1, 1, 1, 4, 1, 1, 1, 1),
|
|
|
|
("2", 1, 1, 1, 4, 2, 2, 2, 1),
|
|
|
|
("2", 1, 2, 1, 4, 3, 2, 2, 2),
|
|
|
|
("2", 2, 2, 2, 4, 4, 4, 3, 2)
|
|
|
|
]
|
|
|
|
for r, ex in zip(rs, expected):
|
|
|
|
self.assertEqual(tuple(r), ex[:len(r)])
|
|
|
|
|
2015-11-09 17:30:37 -05:00
|
|
|
def test_collect_functions(self):
|
|
|
|
df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
|
|
|
|
from pyspark.sql import functions
|
|
|
|
|
|
|
|
self.assertEqual(
|
|
|
|
sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r),
|
|
|
|
[1, 2])
|
|
|
|
self.assertEqual(
|
|
|
|
sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r),
|
|
|
|
[1, 1, 1, 2])
|
|
|
|
self.assertEqual(
|
|
|
|
sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r),
|
|
|
|
["1", "2"])
|
|
|
|
self.assertEqual(
|
|
|
|
sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
|
|
|
|
["1", "2", "2", "2"])
|
|
|
|
|
2015-08-14 16:55:29 -04:00
|
|
|
|
2015-02-03 19:01:56 -05:00
|
|
|
if __name__ == "__main__":
|
2016-01-20 14:11:10 -05:00
|
|
|
from pyspark.sql.tests import *
|
2015-10-22 18:27:11 -04:00
|
|
|
if xmlrunner:
|
|
|
|
unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
|
|
|
|
else:
|
|
|
|
unittest.main()
|