[SPARK-8532] [SQL] In Python's DataFrameWriter, save/saveAsTable/json/parquet/jdbc always override mode
https://issues.apache.org/jira/browse/SPARK-8532 This PR has two changes. First, it fixes the bug that save actions (i.e. `save/saveAsTable/json/parquet/jdbc`) always override mode. Second, it adds input argument `partitionBy` to `save/saveAsTable/parquet`. Author: Yin Huai <yhuai@databricks.com> Closes #6937 from yhuai/SPARK-8532 and squashes the following commits: f972d5d [Yin Huai] davies's comment. d37abd2 [Yin Huai] style. d21290a [Yin Huai] Python doc. 889eb25 [Yin Huai] Minor refactoring and add partitionBy to save, saveAsTable, and parquet. 7fbc24b [Yin Huai] Use None instead of "error" as the default value of mode since JVM-side already uses "error" as the default value. d696dff [Yin Huai] Python style. 88eb6c4 [Yin Huai] If mode is "error", do not call mode method. c40c461 [Yin Huai] Regression test.
This commit is contained in:
parent
da7bbb9435
commit
5ab9fcfb01
|
@ -218,7 +218,10 @@ class DataFrameWriter(object):
|
|||
|
||||
>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
|
||||
"""
|
||||
self._jwrite = self._jwrite.mode(saveMode)
|
||||
# At the JVM side, the default value of mode is already set to "error".
|
||||
# So, if the given saveMode is None, we will not call JVM-side's mode method.
|
||||
if saveMode is not None:
|
||||
self._jwrite = self._jwrite.mode(saveMode)
|
||||
return self
|
||||
|
||||
@since(1.4)
|
||||
|
@ -253,11 +256,12 @@ class DataFrameWriter(object):
|
|||
"""
|
||||
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
|
||||
cols = cols[0]
|
||||
self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
|
||||
if len(cols) > 0:
|
||||
self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
|
||||
return self
|
||||
|
||||
@since(1.4)
|
||||
def save(self, path=None, format=None, mode="error", **options):
|
||||
def save(self, path=None, format=None, mode=None, partitionBy=(), **options):
|
||||
"""Saves the contents of the :class:`DataFrame` to a data source.
|
||||
|
||||
The data source is specified by the ``format`` and a set of ``options``.
|
||||
|
@ -272,11 +276,12 @@ class DataFrameWriter(object):
|
|||
* ``overwrite``: Overwrite existing data.
|
||||
* ``ignore``: Silently ignore this operation if data already exists.
|
||||
* ``error`` (default case): Throw an exception if data already exists.
|
||||
:param partitionBy: names of partitioning columns
|
||||
:param options: all other string options
|
||||
|
||||
>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
|
||||
"""
|
||||
self.mode(mode).options(**options)
|
||||
self.partitionBy(partitionBy).mode(mode).options(**options)
|
||||
if format is not None:
|
||||
self.format(format)
|
||||
if path is None:
|
||||
|
@ -296,7 +301,7 @@ class DataFrameWriter(object):
|
|||
self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
|
||||
|
||||
@since(1.4)
|
||||
def saveAsTable(self, name, format=None, mode="error", **options):
|
||||
def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options):
|
||||
"""Saves the content of the :class:`DataFrame` as the specified table.
|
||||
|
||||
In the case the table already exists, behavior of this function depends on the
|
||||
|
@ -312,15 +317,16 @@ class DataFrameWriter(object):
|
|||
:param name: the table name
|
||||
:param format: the format used to save
|
||||
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
|
||||
:param partitionBy: names of partitioning columns
|
||||
:param options: all other string options
|
||||
"""
|
||||
self.mode(mode).options(**options)
|
||||
self.partitionBy(partitionBy).mode(mode).options(**options)
|
||||
if format is not None:
|
||||
self.format(format)
|
||||
self._jwrite.saveAsTable(name)
|
||||
|
||||
@since(1.4)
|
||||
def json(self, path, mode="error"):
|
||||
def json(self, path, mode=None):
|
||||
"""Saves the content of the :class:`DataFrame` in JSON format at the specified path.
|
||||
|
||||
:param path: the path in any Hadoop supported file system
|
||||
|
@ -333,10 +339,10 @@ class DataFrameWriter(object):
|
|||
|
||||
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
|
||||
"""
|
||||
self._jwrite.mode(mode).json(path)
|
||||
self.mode(mode)._jwrite.json(path)
|
||||
|
||||
@since(1.4)
|
||||
def parquet(self, path, mode="error"):
|
||||
def parquet(self, path, mode=None, partitionBy=()):
|
||||
"""Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
|
||||
|
||||
:param path: the path in any Hadoop supported file system
|
||||
|
@ -346,13 +352,15 @@ class DataFrameWriter(object):
|
|||
* ``overwrite``: Overwrite existing data.
|
||||
* ``ignore``: Silently ignore this operation if data already exists.
|
||||
* ``error`` (default case): Throw an exception if data already exists.
|
||||
:param partitionBy: names of partitioning columns
|
||||
|
||||
>>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
|
||||
"""
|
||||
self._jwrite.mode(mode).parquet(path)
|
||||
self.partitionBy(partitionBy).mode(mode)
|
||||
self._jwrite.parquet(path)
|
||||
|
||||
@since(1.4)
|
||||
def jdbc(self, url, table, mode="error", properties={}):
|
||||
def jdbc(self, url, table, mode=None, properties={}):
|
||||
"""Saves the content of the :class:`DataFrame` to a external database table via JDBC.
|
||||
|
||||
.. note:: Don't create too many partitions in parallel on a large cluster;\
|
||||
|
|
|
@ -539,6 +539,38 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
|
||||
shutil.rmtree(tmpPath)
|
||||
|
||||
def test_save_and_load_builder(self):
|
||||
df = self.df
|
||||
tmpPath = tempfile.mkdtemp()
|
||||
shutil.rmtree(tmpPath)
|
||||
df.write.json(tmpPath)
|
||||
actual = self.sqlCtx.read.json(tmpPath)
|
||||
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
||||
|
||||
schema = StructType([StructField("value", StringType(), True)])
|
||||
actual = self.sqlCtx.read.json(tmpPath, schema)
|
||||
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
|
||||
|
||||
df.write.mode("overwrite").json(tmpPath)
|
||||
actual = self.sqlCtx.read.json(tmpPath)
|
||||
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
||||
|
||||
df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
|
||||
.format("json").save(path=tmpPath)
|
||||
actual =\
|
||||
self.sqlCtx.read.format("json")\
|
||||
.load(path=tmpPath, noUse="this options will not be used in load.")
|
||||
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
||||
|
||||
defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
|
||||
"org.apache.spark.sql.parquet")
|
||||
self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
|
||||
actual = self.sqlCtx.load(path=tmpPath)
|
||||
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
|
||||
self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
|
||||
|
||||
shutil.rmtree(tmpPath)
|
||||
|
||||
def test_help_command(self):
|
||||
# Regression test for SPARK-5464
|
||||
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
|
||||
|
|
Loading…
Reference in a new issue