e0e64ba4b1
The `__eq__` of DataType is not correct, and the class cache is not used correctly (a created class cannot be found again by its dataType), so lots of generated classes pile up in `_cached_cls` and are never released. Also, all instances of the same DataType share one hash code, so a dict holding many such objects degenerates into a single hash bucket, and access to that dict becomes very slow (depending on the CPython implementation). This PR also improves the performance of inferSchema (it avoids creating unnecessary converters for objects).

cc pwendell JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #4808 from davies/leak and squashes the following commits:

6a322a4 [Davies Liu] tests refactor
3da44fc [Davies Liu] fix __eq__ of Singleton
534ac90 [Davies Liu] add more checks
46999dc [Davies Liu] fix tests
d9ae973 [Davies Liu] fix memory leak in sql
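For context, the fix hinges on the standard Python contract that objects which compare equal must also hash equal, and that a memoizing cache keyed by DataType only works when that contract holds. The snippet below is a minimal illustrative sketch of that pattern, not the actual Spark patch: the names `DataType`, `LongType`, `_cached_cls`, and `_create_cls` mirror those mentioned above, but their bodies are simplified assumptions.

class DataType(object):
    """Simplified stand-in: instances of the same class with the same
    fields compare equal, and equal instances share a hash code."""

    def __repr__(self):
        return self.__class__.__name__

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # hash a stable per-type representation: equal instances land in the
        # same dict bucket, while distinct types keep distinct hashes
        return hash(repr(self))


class LongType(DataType):
    pass


_cached_cls = {}  # dataType -> generated class; sound only if keys hash/compare consistently


def _create_cls(dataType):
    cls = _cached_cls.get(dataType)
    if cls is None:
        cls = type('Row_%d' % len(_cached_cls), (tuple,), {})  # placeholder body
        _cached_cls[dataType] = cls
    return cls


assert LongType() == LongType()
assert hash(LongType()) == hash(LongType())
assert _create_cls(LongType()) is _create_cls(LongType())  # cache hit: no new class, no leak

Under a broken `__eq__`/`__hash__` pair, the final assert would fail and every call would add a new class to `_cached_cls`, which is the leak the commits above fix.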
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
|
|
Unit tests for pyspark.sql; additional tests are implemented as doctests in
|
|
individual modules.
|
|
"""
|
|
import os
|
|
import sys
|
|
import pydoc
|
|
import shutil
|
|
import tempfile
|
|
import pickle
|
|
|
|
import py4j
|
|
|
|
if sys.version_info[:2] <= (2, 6):
|
|
try:
|
|
import unittest2 as unittest
|
|
except ImportError:
|
|
sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
|
|
sys.exit(1)
|
|
else:
|
|
import unittest
|
|
|
|
from pyspark.sql import SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase


class ExamplePointUDT(UserDefinedType):
    """
    User-defined type (UDT) for ExamplePoint.
    """

    @classmethod
    def sqlType(self):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return 'pyspark.tests'

    @classmethod
    def scalaUDT(cls):
        return 'org.apache.spark.sql.test.ExamplePointUDT'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return ExamplePoint(datum[0], datum[1])


class ExamplePoint:
    """
    An example class to demonstrate UDT in Scala, Java, and Python.
    """

    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        return isinstance(other, ExamplePoint) and \
            other.x == self.x and other.y == self.y


class DataTypeTests(unittest.TestCase):
    # regression test for SPARK-6055
    def test_data_type_eq(self):
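        # equality must survive a pickle round-trip so that deserialized
        # DataType instances still hit the same _cached_cls entry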
        lt = LongType()
        lt2 = pickle.loads(pickle.dumps(LongType()))
        self.assertEqual(lt, lt2)


class SQLTests(ReusedPySparkTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
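        # reserve a unique temp path, then unlink it so tests can reuse the
        # name for their own output directories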
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(cls.tempdir.name)
        cls.sqlCtx = SQLContext(cls.sc)
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        rdd = cls.sc.parallelize(cls.testData)
        cls.df = rdd.toDF()

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)

    def test_udf(self):
        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
        self.assertEqual(row[0], 5)

    def test_udf2(self):
        self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
        self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
        [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
        self.assertEqual(4, res[0])

    def test_udf_with_array_type(self):
        d = [Row(l=range(3), d={"key": range(5)})]
        rdd = self.sc.parallelize(d)
        self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
        self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
        self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
        [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
        self.assertEqual(range(3), l1)
        self.assertEqual(1, l2)

    def test_broadcast_in_udf(self):
        bar = {"a": "aa", "b": "bb", "c": "abc"}
        foo = self.sc.broadcast(bar)
        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
        self.assertEqual("abc", res[0])
        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
        self.assertEqual("", res[0])

    def test_basic_functions(self):
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.sqlCtx.jsonRDD(rdd)
        df.count()
        df.collect()
        df.schema

        # cache and checkpoint
        self.assertFalse(df.is_cached)
        df.persist()
        df.unpersist()
        df.cache()
        self.assertTrue(df.is_cached)
        self.assertEqual(2, df.count())

        df.registerTempTable("temp")
        df = self.sqlCtx.sql("select foo from temp")
        df.count()
        df.collect()

    def test_apply_schema_to_row(self):
        df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
        df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema)
        self.assertEqual(df.collect(), df2.collect())

        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
        df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
        self.assertEqual(10, df3.count())

    def test_serialize_nested_array_and_map(self):
        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
        rdd = self.sc.parallelize(d)
        df = self.sqlCtx.createDataFrame(rdd)
        row = df.head()
        self.assertEqual(1, len(row.l))
        self.assertEqual(1, row.l[0].a)
        self.assertEqual("2", row.d["key"].d)

        l = df.map(lambda x: x.l).first()
        self.assertEqual(1, len(l))
        self.assertEqual('s', l[0].b)

        d = df.map(lambda x: x.d).first()
        self.assertEqual(1, len(d))
        self.assertEqual(1.0, d["key"].c)

        row = df.map(lambda x: x.d["key"]).first()
        self.assertEqual(1.0, row.c)
        self.assertEqual("2", row.d)

    def test_infer_schema(self):
        d = [Row(l=[], d={}, s=None),
             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
        rdd = self.sc.parallelize(d)
        df = self.sqlCtx.createDataFrame(rdd)
        self.assertEqual([], df.map(lambda r: r.l).first())
        self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
        df.registerTempTable("test")
        result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
        self.assertEqual(1, result.head()[0])

        df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
        self.assertEqual(df.schema, df2.schema)
        self.assertEqual({}, df2.map(lambda r: r.d).first())
        self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
        df2.registerTempTable("test2")
        result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
        self.assertEqual(1, result.head()[0])

    def test_infer_nested_schema(self):
        NestedRow = Row("f1", "f2")
        nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
                                          NestedRow([2, 3], {"row2": 2.0})])
        df = self.sqlCtx.inferSchema(nestedRdd1)
        self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])

        nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
                                          NestedRow([[2, 3], [3, 4]], [2, 3])])
        df = self.sqlCtx.inferSchema(nestedRdd2)
        self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])

        from collections import namedtuple
        CustomRow = namedtuple('CustomRow', 'field1 field2')
        rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
                                   CustomRow(field1=2, field2="row2"),
                                   CustomRow(field1=3, field2="row3")])
        df = self.sqlCtx.inferSchema(rdd)
        self.assertEqual(Row(field1=1, field2=u'row1'), df.first())

    def test_apply_schema(self):
        from datetime import date, datetime
        rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
                                    date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
                                    {"a": 1}, (2,), [1, 2, 3], None)])
        schema = StructType([
            StructField("byte1", ByteType(), False),
            StructField("byte2", ByteType(), False),
            StructField("short1", ShortType(), False),
            StructField("short2", ShortType(), False),
            StructField("int1", IntegerType(), False),
            StructField("float1", FloatType(), False),
            StructField("date1", DateType(), False),
            StructField("time1", TimestampType(), False),
            StructField("map1", MapType(StringType(), IntegerType(), False), False),
            StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
            StructField("list1", ArrayType(ByteType(), False), False),
            StructField("null1", DoubleType(), True)])
        df = self.sqlCtx.applySchema(rdd, schema)
        results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
                                    x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
        r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
             datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
        self.assertEqual(r, results.first())

        df.registerTempTable("table2")
        r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
                            "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
                            "float1 + 1.5 as float1 FROM table2").first()

        self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))

        from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
        rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
                                    {"a": 1}, (2,), [1, 2, 3])])
        abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
        schema = _parse_schema_abstract(abstract)
        typedSchema = _infer_schema_type(rdd.first(), schema)
        df = self.sqlCtx.applySchema(rdd, typedSchema)
        r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
        self.assertEqual(r, tuple(df.first()))

    def test_struct_in_map(self):
        d = [Row(m={Row(i=1): Row(s="")})]
        df = self.sc.parallelize(d).toDF()
        k, v = df.head().m.items()[0]
        self.assertEqual(1, k.i)
        self.assertEqual("", v.s)

    def test_convert_row_to_dict(self):
        row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
        self.assertEqual(1, row.asDict()['l'][0].a)
        df = self.sc.parallelize([row]).toDF()
        df.registerTempTable("test")
        row = self.sqlCtx.sql("select l, d from test").head()
        self.assertEqual(1, row.asDict()["l"][0].a)
        self.assertEqual(1.0, row.asDict()['d']['key'].c)

    def test_infer_schema_with_udt(self):
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df = self.sc.parallelize([row]).toDF()
        schema = df.schema
        field = [f for f in schema.fields if f.name == "point"][0]
        self.assertEqual(type(field.dataType), ExamplePointUDT)
        df.registerTempTable("labeled_point")
        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))

    def test_apply_schema_with_udt(self):
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row = (1.0, ExamplePoint(1.0, 2.0))
        rdd = self.sc.parallelize([row])
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", ExamplePointUDT(), False)])
        df = rdd.toDF(schema)
        point = df.head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))

    def test_parquet_with_udt(self):
        from pyspark.sql.tests import ExamplePoint
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df0 = self.sc.parallelize([row]).toDF()
        output_dir = os.path.join(self.tempdir.name, "labeled_point")
        df0.saveAsParquetFile(output_dir)
        df1 = self.sqlCtx.parquetFile(output_dir)
        point = df1.head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))

    def test_column_operators(self):
        ci = self.df.key
        cs = self.df.value
        c = ci == cs
        self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
        self.assertTrue(all(isinstance(c, Column) for c in rcc))
        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
        self.assertTrue(all(isinstance(c, Column) for c in cb))
        cbool = (ci & ci), (ci | ci), (~ci)
        self.assertTrue(all(isinstance(c, Column) for c in cbool))
        css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
        self.assertTrue(all(isinstance(c, Column) for c in css))
        self.assertTrue(isinstance(ci.cast(LongType()), Column))

    def test_column_select(self):
        df = self.df
        self.assertEqual(self.testData, df.select("*").collect())
        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())

    def test_aggregator(self):
        df = self.df
        g = df.groupBy()
        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())

        from pyspark.sql import functions
        self.assertEqual((0, u'99'),
                         tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
        self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
        self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])

    def test_save_and_load(self):
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.save(tmpPath, "org.apache.spark.sql.json", "error")
        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))

        schema = StructType([StructField("value", StringType(), True)])
        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))

        df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))

        df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
                noUse="this option will not be used in save.")
        actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
                                  noUse="this option will not be used in load.")
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))

        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                    "org.apache.spark.sql.parquet")
        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        actual = self.sqlCtx.load(path=tmpPath)
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

        shutil.rmtree(tmpPath)

    def test_help_command(self):
        # Regression test for SPARK-5464
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.sqlCtx.jsonRDD(rdd)
        # render_doc() reproduces the help() exception without printing output
        pydoc.render_doc(df)
        pydoc.render_doc(df.foo)
        pydoc.render_doc(df.take(1))

    def test_infer_long_type(self):
        longrow = [Row(f1='a', f2=100000000000000)]
        df = self.sc.parallelize(longrow).toDF()
        self.assertEqual(df.schema.fields[1].dataType, LongType())

        # this saving as Parquet caused issues as well.
        output_dir = os.path.join(self.tempdir.name, "infer_long_type")
        df.saveAsParquetFile(output_dir)
        df1 = self.sqlCtx.parquetFile(output_dir)
        self.assertEqual('a', df1.first().f1)
        self.assertEqual(100000000000000, df1.first().f2)

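        # integers of any magnitude, including values past 2**31 and 2**64,
        # should infer as LongType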
        self.assertEqual(_infer_type(1), LongType())
        self.assertEqual(_infer_type(2**10), LongType())
        self.assertEqual(_infer_type(2**20), LongType())
        self.assertEqual(_infer_type(2**31 - 1), LongType())
        self.assertEqual(_infer_type(2**31), LongType())
        self.assertEqual(_infer_type(2**61), LongType())
        self.assertEqual(_infer_type(2**71), LongType())


class HiveContextSQLTests(ReusedPySparkTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
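        # probe for Hive support in the JVM; if it is missing, leave sqlCtx
        # as None so the tests below can skip themselves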
        try:
            cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
        except py4j.protocol.Py4JError:
            cls.sqlCtx = None
            return
        os.unlink(cls.tempdir.name)
        _scala_HiveContext =\
            cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
        cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        cls.df = cls.sc.parallelize(cls.testData).toDF()

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)

    def test_save_and_load_table(self):
        if self.sqlCtx is None:
            return  # no hive available, skipped

        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
                                                 "org.apache.spark.sql.json")
        self.assertTrue(
            sorted(df.collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertTrue(
            sorted(df.collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
        self.sqlCtx.sql("DROP TABLE externalJsonTable")

        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
        schema = StructType([StructField("value", StringType(), True)])
        actual = self.sqlCtx.createExternalTable("externalJsonTable",
                                                 source="org.apache.spark.sql.json",
                                                 schema=schema, path=tmpPath,
                                                 noUse="this option will not be used")
        self.assertTrue(
            sorted(df.collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertTrue(
            sorted(df.select("value").collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
        self.sqlCtx.sql("DROP TABLE savedJsonTable")
        self.sqlCtx.sql("DROP TABLE externalJsonTable")

        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                    "org.apache.spark.sql.parquet")
        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
        actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
        self.assertTrue(
            sorted(df.collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertTrue(
            sorted(df.collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
        self.sqlCtx.sql("DROP TABLE savedJsonTable")
        self.sqlCtx.sql("DROP TABLE externalJsonTable")
        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

        shutil.rmtree(tmpPath)


if __name__ == "__main__":
    unittest.main()