#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import shutil
import tempfile

from pyspark.sql.functions import col
from pyspark.sql.readwriter import DataFrameWriterV2
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase


class ReadwriterTests(ReusedSQLTestCase):
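        # Round-trip the DataFrame through JSON and check the contents survive.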
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.write.json(tmpPath)
        actual = self.spark.read.json(tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
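        # Read back with an explicit schema; only the declared column is loaded.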
        schema = StructType([StructField("value", StringType(), True)])
        actual = self.spark.read.json(tmpPath, schema)
        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
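        # Overwrite the same path via the positional mode argument.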
        df.write.json(tmpPath, "overwrite")
        actual = self.spark.read.json(tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
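        # Unknown options such as noUse are passed through and ignored by the JSON source.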
        df.write.save(format="json", mode="overwrite", path=tmpPath,
                      noUse="this option will not be used in save.")
        actual = self.spark.read.load(format="json", path=tmpPath,
                                      noUse="this option will not be used in load.")
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
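        # Temporarily switch the default data source so load() needs no explicit format.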
        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                    "org.apache.spark.sql.parquet")
        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        actual = self.spark.read.load(path=tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
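        # A None option value should be accepted by the writer (treated as unset).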
        csvpath = os.path.join(tempfile.mkdtemp(), 'data')
        df.write.option('quote', None).format('csv').save(csvpath)

        shutil.rmtree(tmpPath)
def test_save_and_load_builder(self):
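        # Same scenarios as test_save_and_load, exercised through the chained builder API.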
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.write.json(tmpPath)
        actual = self.spark.read.json(tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        schema = StructType([StructField("value", StringType(), True)])
        actual = self.spark.read.json(tmpPath, schema)
        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
df.write.mode("overwrite").json(tmpPath)
|
|
|
|
actual = self.spark.read.json(tmpPath)
|
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
|
|
|
|
.option("noUse", "this option will not be used in save.")\
|
|
|
|
.format("json").save(path=tmpPath)
|
|
|
|
actual =\
|
|
|
|
self.spark.read.format("json")\
|
|
|
|
.load(path=tmpPath, noUse="this options will not be used in load.")
|
|
|
|
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                    "org.apache.spark.sql.parquet")
        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        actual = self.spark.read.load(path=tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
shutil.rmtree(tmpPath)
def test_bucketed_write(self):
        data = [
            (1, "foo", 3.0), (2, "foo", 5.0),
            (3, "bar", -1.0), (4, "bar", 6.0),
        ]
        df = self.spark.createDataFrame(data, ["x", "y", "z"])
        def count_bucketed_cols(names, table="pyspark_bucket"):
            """Given a sequence of column names and a table name,
            query the catalog and return the number of columns
            used for bucketing.
            """
            cols = self.spark.catalog.listColumns(table)
            num = len([c for c in cols if c.name in names and c.isBucket])
            return num
with self.table("pyspark_bucket"):
|
|
|
|
# Test write with one bucketing column
|
|
|
|
df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
|
|
|
|
self.assertEqual(count_bucketed_cols(["x"]), 1)
|
|
|
|
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
            # Test write with two bucketing columns
            df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
            # Test write with bucket and sort
            df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x"]), 1)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
            # Test write with a list of columns
            df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
            # Test write with bucket and sort with a list of columns
            (df.write.bucketBy(2, "x")
                .sortBy(["y", "z"])
                .mode("overwrite").saveAsTable("pyspark_bucket"))
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
            # Test write with bucket and sort with multiple columns
            (df.write.bucketBy(2, "x")
                .sortBy("y", "z")
                .mode("overwrite").saveAsTable("pyspark_bucket"))
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
def test_insert_into(self):
        df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"])
with self.table("test_table"):
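            # Create the table; subsequent insertInto calls append unless overwrite is requested.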
            df.write.saveAsTable("test_table")
            self.assertEqual(2, self.spark.sql("select * from test_table").count())
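            # insertInto appends by default.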
            df.write.insertInto("test_table")
            self.assertEqual(4, self.spark.sql("select * from test_table").count())
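            # mode("overwrite") makes insertInto replace the existing rows.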
df.write.mode("overwrite").insertInto("test_table")
|
|
|
|
self.assertEqual(2, self.spark.sql("select * from test_table").count())
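            # The overwrite flag can also be passed positionally.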
            df.write.insertInto("test_table", True)
            self.assertEqual(2, self.spark.sql("select * from test_table").count())
            df.write.insertInto("test_table", False)
            self.assertEqual(4, self.spark.sql("select * from test_table").count())
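            # An explicit overwrite=False takes precedence over the writer's mode.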
df.write.mode("overwrite").insertInto("test_table", False)
|
|
|
|
self.assertEqual(6, self.spark.sql("select * from test_table").count())


class ReadwriterV2Tests(ReusedSQLTestCase):
def test_api(self):
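        # Every builder method should return a DataFrameWriterV2 so calls can be chained.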
        df = self.df
        writer = df.writeTo("testcat.t")
        self.assertIsInstance(writer, DataFrameWriterV2)
        self.assertIsInstance(writer.option("property", "value"), DataFrameWriterV2)
        self.assertIsInstance(writer.options(property="value"), DataFrameWriterV2)
        self.assertIsInstance(writer.using("source"), DataFrameWriterV2)
        self.assertIsInstance(writer.partitionedBy("id"), DataFrameWriterV2)
        self.assertIsInstance(writer.partitionedBy(col("id")), DataFrameWriterV2)
        self.assertIsInstance(writer.tableProperty("foo", "bar"), DataFrameWriterV2)
def test_partitioning_functions(self):
        import datetime
        from pyspark.sql.functions import years, months, days, hours, bucket

        df = self.spark.createDataFrame(
            [(1, datetime.datetime(2000, 1, 1), "foo")],
            ("id", "ts", "value")
        )
writer = df.writeTo("testcat.t")
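        # Each partitioning transform is accepted and keeps the builder chainable.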
        self.assertIsInstance(writer.partitionedBy(years("ts")), DataFrameWriterV2)
        self.assertIsInstance(writer.partitionedBy(months("ts")), DataFrameWriterV2)
        self.assertIsInstance(writer.partitionedBy(days("ts")), DataFrameWriterV2)
        self.assertIsInstance(writer.partitionedBy(hours("ts")), DataFrameWriterV2)
        self.assertIsInstance(writer.partitionedBy(bucket(11, "id")), DataFrameWriterV2)
        self.assertIsInstance(writer.partitionedBy(bucket(11, col("id"))), DataFrameWriterV2)
        self.assertIsInstance(
            writer.partitionedBy(bucket(3, "id"), hours(col("ts"))), DataFrameWriterV2
        )


if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_readwriter import *

try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)