spark-instrumented-optimizer/python/pyspark/sql/tests.py
Davies Liu e0e64ba4b1 [SPARK-6055] [PySpark] fix incorrect __eq__ of DataType
The __eq__ of DataType is not correct, so the class cache is not used correctly (a created class cannot be found by its dataType); as a result, lots of classes are created (saved in _cached_cls) and never released.

Also, all instances of the same DataType have the same hash code, so there will be many objects in a dict with the same hash code, which ends up behaving like a hash attack: accessing this dict is very slow (depending on the implementation of CPython).

This PR also improves the performance of inferSchema (by avoiding the unnecessary conversion of objects).

cc pwendell  JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #4808 from davies/leak and squashes the following commits:

6a322a4 [Davies Liu] tests refactor
3da44fc [Davies Liu] fix __eq__ of Singleton
534ac90 [Davies Liu] add more checks
46999dc [Davies Liu] fix tests
d9ae973 [Davies Liu] fix memory leak in sql
2015-02-27 20:07:17 -08:00

497 lines
21 KiB
Python

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for pyspark.sql; additional tests are implemented as doctests in
individual modules.
"""
import os
import sys
import pydoc
import shutil
import tempfile
import pickle
import py4j
if sys.version_info[:2] <= (2, 6):
try:
import unittest2 as unittest
except ImportError:
sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
sys.exit(1)
else:
import unittest
from pyspark.sql import SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
class ExamplePointUDT(UserDefinedType):
    """
    User-defined type (UDT) for ExamplePoint.

    An ExamplePoint is serialized as the two-element list ``[x, y]`` and
    rebuilt from such a list on deserialization.
    """

    @classmethod
    def sqlType(cls):
        """Return the SQL type used to store a serialized point."""
        # The second argument is presumably ``containsNull`` -- confirm
        # against pyspark.sql.types.ArrayType.
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        """Return the name of the Python module associated with this UDT."""
        # NOTE(review): the tests in this file import ExamplePointUDT from
        # 'pyspark.sql.tests'; confirm that 'pyspark.tests' is the intended
        # module path here.
        return 'pyspark.tests'

    @classmethod
    def scalaUDT(cls):
        """Return the fully qualified name of the matching Scala UDT class."""
        return 'org.apache.spark.sql.test.ExamplePointUDT'

    def serialize(self, obj):
        """Convert a point-like object into the list ``[obj.x, obj.y]``."""
        return [obj.x, obj.y]

    def deserialize(self, datum):
        """Build an ExamplePoint from the first two items of ``datum``."""
        return ExamplePoint(datum[0], datum[1])
class ExamplePoint:
    """
    An example class to demonstrate UDT in Scala, Java, and Python.

    Two points are equal when both are ExamplePoint instances with equal
    ``x`` and ``y`` attributes.
    """

    # UDT instance attached to the class as ``__UDT__``.
    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        return isinstance(other, ExamplePoint) and \
            other.x == self.x and other.y == self.y

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this
        # method two equal points would also compare as "not equal".
        return not self.__eq__(other)
class DataTypeTests(unittest.TestCase):
# regression test for SPARK-6055
def test_data_type_eq(self):
lt = LongType()
lt2 = pickle.loads(pickle.dumps(LongType()))
self.assertEquals(lt, lt2)
class SQLTests(ReusedPySparkTestCase):
@classmethod
def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(cls.tempdir.name)
cls.sqlCtx = SQLContext(cls.sc)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
rdd = cls.sc.parallelize(cls.testData)
cls.df = rdd.toDF()
@classmethod
def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
def test_udf(self):
self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
[row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], 5)
def test_udf2(self):
self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
def test_udf_with_array_type(self):
d = [Row(l=range(3), d={"key": range(5)})]
rdd = self.sc.parallelize(d)
self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
self.assertEqual(range(3), l1)
self.assertEqual(1, l2)
def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}
foo = self.sc.broadcast(bar)
self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
[res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
self.assertEqual("abc", res[0])
[res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
self.assertEqual("", res[0])
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.sqlCtx.jsonRDD(rdd)
df.count()
df.collect()
df.schema
# cache and checkpoint
self.assertFalse(df.is_cached)
df.persist()
df.unpersist()
df.cache()
self.assertTrue(df.is_cached)
self.assertEqual(2, df.count())
df.registerTempTable("temp")
df = self.sqlCtx.sql("select foo from temp")
df.count()
df.collect()
def test_apply_schema_to_row(self):
df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema)
self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)
df = self.sqlCtx.createDataFrame(rdd)
row = df.head()
self.assertEqual(1, len(row.l))
self.assertEqual(1, row.l[0].a)
self.assertEqual("2", row.d["key"].d)
l = df.map(lambda x: x.l).first()
self.assertEqual(1, len(l))
self.assertEqual('s', l[0].b)
d = df.map(lambda x: x.d).first()
self.assertEqual(1, len(d))
self.assertEqual(1.0, d["key"].c)
row = df.map(lambda x: x.d["key"]).first()
self.assertEqual(1.0, row.c)
self.assertEqual("2", row.d)
def test_infer_schema(self):
d = [Row(l=[], d={}, s=None),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
df = self.sqlCtx.createDataFrame(rdd)
self.assertEqual([], df.map(lambda r: r.l).first())
self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
df.registerTempTable("test")
result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
self.assertEqual(df.schema, df2.schema)
self.assertEqual({}, df2.map(lambda r: r.d).first())
self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
df2.registerTempTable("test2")
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
def test_infer_nested_schema(self):
NestedRow = Row("f1", "f2")
nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
NestedRow([2, 3], {"row2": 2.0})])
df = self.sqlCtx.inferSchema(nestedRdd1)
self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
NestedRow([[2, 3], [3, 4]], [2, 3])])
df = self.sqlCtx.inferSchema(nestedRdd2)
self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
from collections import namedtuple
CustomRow = namedtuple('CustomRow', 'field1 field2')
rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
CustomRow(field1=2, field2="row2"),
CustomRow(field1=3, field2="row3")])
df = self.sqlCtx.inferSchema(rdd)
self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
def test_apply_schema(self):
from datetime import date, datetime
rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
{"a": 1}, (2,), [1, 2, 3], None)])
schema = StructType([
StructField("byte1", ByteType(), False),
StructField("byte2", ByteType(), False),
StructField("short1", ShortType(), False),
StructField("short2", ShortType(), False),
StructField("int1", IntegerType(), False),
StructField("float1", FloatType(), False),
StructField("date1", DateType(), False),
StructField("time1", TimestampType(), False),
StructField("map1", MapType(StringType(), IntegerType(), False), False),
StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
StructField("list1", ArrayType(ByteType(), False), False),
StructField("null1", DoubleType(), True)])
df = self.sqlCtx.applySchema(rdd, schema)
results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
self.assertEqual(r, results.first())
df.registerTempTable("table2")
r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
"short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
"float1 + 1.5 as float1 FROM table2").first()
self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
{"a": 1}, (2,), [1, 2, 3])])
abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
schema = _parse_schema_abstract(abstract)
typedSchema = _infer_schema_type(rdd.first(), schema)
df = self.sqlCtx.applySchema(rdd, typedSchema)
r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
self.assertEqual(r, tuple(df.first()))
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
df = self.sc.parallelize(d).toDF()
k, v = df.head().m.items()[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
df = self.sc.parallelize([row]).toDF()
df.registerTempTable("test")
row = self.sqlCtx.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
self.assertEqual(1.0, row.asDict()['d']['key'].c)
def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
df = self.sc.parallelize([row]).toDF()
schema = df.schema
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
df.registerTempTable("labeled_point")
point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
self.assertEqual(point, ExamplePoint(1.0, 2.0))
def test_apply_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = (1.0, ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
df = rdd.toDF(schema)
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
df0 = self.sc.parallelize([row]).toDF()
output_dir = os.path.join(self.tempdir.name, "labeled_point")
df0.saveAsParquetFile(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
point = df1.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_column_operators(self):
ci = self.df.key
cs = self.df.value
c = ci == cs
self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
self.assertTrue(all(isinstance(c, Column) for c in rcc))
cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
self.assertTrue(all(isinstance(c, Column) for c in cb))
cbool = (ci & ci), (ci | ci), (~ci)
self.assertTrue(all(isinstance(c, Column) for c in cbool))
css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
self.assertTrue(all(isinstance(c, Column) for c in css))
self.assertTrue(isinstance(ci.cast(LongType()), Column))
def test_column_select(self):
df = self.df
self.assertEqual(self.testData, df.select("*").collect())
self.assertEqual(self.testData, df.select(df.key, df.value).collect())
self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
def test_aggregator(self):
df = self.df
g = df.groupBy()
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
from pyspark.sql import functions
self.assertEqual((0, u'99'),
tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
df.save(tmpPath, "org.apache.spark.sql.json", "error")
actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
schema = StructType([StructField("value", StringType(), True)])
actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
noUse="this options will not be used in save.")
actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
noUse="this options will not be used in load.")
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
actual = self.sqlCtx.load(path=tmpPath)
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
shutil.rmtree(tmpPath)
def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.sqlCtx.jsonRDD(rdd)
# render_doc() reproduces the help() exception without printing output
pydoc.render_doc(df)
pydoc.render_doc(df.foo)
pydoc.render_doc(df.take(1))
def test_infer_long_type(self):
longrow = [Row(f1='a', f2=100000000000000)]
df = self.sc.parallelize(longrow).toDF()
self.assertEqual(df.schema.fields[1].dataType, LongType())
# this saving as Parquet caused issues as well.
output_dir = os.path.join(self.tempdir.name, "infer_long_type")
df.saveAsParquetFile(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
self.assertEquals('a', df1.first().f1)
self.assertEquals(100000000000000, df1.first().f2)
self.assertEqual(_infer_type(1), LongType())
self.assertEqual(_infer_type(2**10), LongType())
self.assertEqual(_infer_type(2**20), LongType())
self.assertEqual(_infer_type(2**31 - 1), LongType())
self.assertEqual(_infer_type(2**31), LongType())
self.assertEqual(_infer_type(2**61), LongType())
self.assertEqual(_infer_type(2**71), LongType())
class HiveContextSQLTests(ReusedPySparkTestCase):
    """
    Tests that need a HiveContext.

    setUpClass probes the JVM for the Hive classes; when they cannot be
    loaded ``cls.sqlCtx`` is set to None and the tests skip themselves.
    """

    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        # NamedTemporaryFile is used only to reserve a unique path name.
        # Close and delete it straight away so that neither the handle nor
        # the file is left behind when Hive turns out to be unavailable.
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        cls.tempdir.close()
        os.unlink(cls.tempdir.name)
        try:
            cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
        except py4j.protocol.Py4JError:
            # Hive classes are not available in this JVM.
            cls.sqlCtx = None
            return
        _scala_HiveContext = \
            cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
        cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        cls.df = cls.sc.parallelize(cls.testData).toDF()

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)

    def test_save_and_load_table(self):
        if self.sqlCtx is None:
            # Report the test as skipped instead of silently passing.
            self.skipTest("Hive is not available")
        df = self.df
        # mkdtemp() reserves a unique name; the directory itself is removed
        # so that the first saveAsTable call creates it.
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        try:
            df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
            actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
                                                     "org.apache.spark.sql.json")
            self.assertEqual(
                sorted(df.collect()),
                sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
            self.assertEqual(
                sorted(df.collect()),
                sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
            self.sqlCtx.sql("DROP TABLE externalJsonTable")

            # Overwrite the table, then read it back with an explicit schema.
            df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite",
                           path=tmpPath)
            schema = StructType([StructField("value", StringType(), True)])
            actual = self.sqlCtx.createExternalTable("externalJsonTable",
                                                     source="org.apache.spark.sql.json",
                                                     schema=schema, path=tmpPath,
                                                     noUse="this options will not be used")
            self.assertEqual(
                sorted(df.collect()),
                sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
            self.assertEqual(
                sorted(df.select("value").collect()),
                sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
            self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
            self.sqlCtx.sql("DROP TABLE savedJsonTable")
            self.sqlCtx.sql("DROP TABLE externalJsonTable")

            # Rely on spark.sql.sources.default instead of an explicit source.
            defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                        "org.apache.spark.sql.parquet")
            self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
            try:
                df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
                actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
                self.assertEqual(
                    sorted(df.collect()),
                    sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
                self.assertEqual(
                    sorted(df.collect()),
                    sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
                self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
                self.sqlCtx.sql("DROP TABLE savedJsonTable")
                self.sqlCtx.sql("DROP TABLE externalJsonTable")
            finally:
                # Restore the default source even if an assertion fails, so
                # the changed setting does not leak into other tests.
                self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
        finally:
            # The directory may not exist if an early step failed.
            shutil.rmtree(tmpPath, ignore_errors=True)
# Run every test case defined in this module when the file is executed as a
# script.
if __name__ == "__main__":
    unittest.main()