[SPARK-7738] [SQL] [PySpark] add reader and writer API in Python
cc rxin, please take a quick look, I'm working on tests.

Author: Davies Liu <davies@databricks.com>

Closes #6238 from davies/readwrite and squashes the following commits:

c7200eb [Davies Liu] update tests
9cbf01b [Davies Liu] Merge branch 'master' of github.com:apache/spark into readwrite
f0c5a04 [Davies Liu] use sqlContext.read.load
5f68bc8 [Davies Liu] update tests
6437e9a [Davies Liu] Merge branch 'master' of github.com:apache/spark into readwrite
bcc6668 [Davies Liu] add reader amd writer API in Python
This commit is contained in:
parent c12dff9b82
commit 4de74d2602
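In short, the patch turns the flat `SQLContext.load` / `DataFrame.save` entry points into thin wrappers over a new reader/writer pair. A minimal before-and-after sketch, assuming an existing `sqlContext` and `df`; the JSON paths are made up:

    # old style (kept, now delegating to the new classes)
    people = sqlContext.load("/tmp/people.json", "org.apache.spark.sql.json")
    people.save("/tmp/people_out", "org.apache.spark.sql.json", "overwrite")

    # new style
    people = sqlContext.read.json("/tmp/people.json")
    people.write.json("/tmp/people_out", "overwrite")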
core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala

@@ -50,8 +50,15 @@ private[spark] object PythonUtils {
   /**
    * Convert list of T into seq of T (for calling API with varargs)
    */
-  def toSeq[T](cols: JList[T]): Seq[T] = {
-    cols.toList.toSeq
+  def toSeq[T](vs: JList[T]): Seq[T] = {
+    vs.toList.toSeq
   }
 
+  /**
+   * Convert list of T into array of T (for calling API with array)
+   */
+  def toArray[T](vs: JList[T]): Array[T] = {
+    vs.toArray().asInstanceOf[Array[T]]
+  }
+
   /**
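The new `toArray` helper is what the Python side calls through Py4J when a real Java array is needed, for example for the `predicates` argument of the JDBC reader added below. A sketch of that call path, assuming a running `sqlContext`; the predicate strings are made up:

    # a plain Python list is handed to the JVM helper, which hands back a Java array
    arr = sqlContext._sc._jvm.PythonUtils.toArray(["id < 500", "id >= 500"])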
python/pyspark/sql/__init__.py

@@ -58,6 +58,7 @@ from pyspark.sql.context import SQLContext, HiveContext
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
 from pyspark.sql.group import GroupedData
+from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
 
 __all__ = [
     'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
python/pyspark/sql/context.py

@@ -31,6 +31,7 @@ from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
 from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
 from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.readwriter import DataFrameReader
 
 try:
     import pandas
@@ -457,19 +458,7 @@ class SQLContext(object):
 
         Optionally, a schema can be provided as the schema of the returned DataFrame.
         """
-        if path is not None:
-            options["path"] = path
-        if source is None:
-            source = self.getConf("spark.sql.sources.default",
-                                  "org.apache.spark.sql.parquet")
-        if schema is None:
-            df = self._ssql_ctx.load(source, options)
-        else:
-            if not isinstance(schema, StructType):
-                raise TypeError("schema should be StructType")
-            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
-            df = self._ssql_ctx.load(source, scala_datatype, options)
-        return DataFrame(df, self)
+        return self.read.load(path, source, schema, **options)
 
     def createExternalTable(self, tableName, path=None, source=None,
                             schema=None, **options):
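`SQLContext.load` keeps its signature but now just forwards to `read.load`, so both spellings below return the same `DataFrame` (path and source name are illustrative):

    df1 = sqlContext.load("/tmp/people.json", source="org.apache.spark.sql.json")
    df2 = sqlContext.read.load("/tmp/people.json", format="org.apache.spark.sql.json")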
@@ -567,6 +556,19 @@ class SQLContext(object):
         """Removes all cached tables from the in-memory cache. """
         self._ssql_ctx.clearCache()
 
+    @property
+    def read(self):
+        """
+        Returns a :class:`DataFrameReader` that can be used to read data
+        in as a :class:`DataFrame`.
+
+        ::note: Experimental
+
+        >>> sqlContext.read
+        <pyspark.sql.readwriter.DataFrameReader object at ...>
+        """
+        return DataFrameReader(self)
+
 
 class HiveContext(SQLContext):
     """A variant of Spark SQL that integrates with data stored in Hive.
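A sketch of the new reader entry point, reusing the `sqlContext` from the doctests; the file path is hypothetical:

    from pyspark.sql.types import StructType, StructField, StringType

    reader = sqlContext.read                          # a DataFrameReader bound to this SQLContext
    people = reader.json("/tmp/people.json")          # schema inferred from the data
    schema = StructType([StructField("name", StringType(), True)])
    named = reader.json("/tmp/people.json", schema)   # or supplied explicitly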
python/pyspark/sql/dataframe.py

@@ -29,9 +29,10 @@ from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.sql.types import *
 from pyspark.sql.types import _create_cls, _parse_datatype_json_string
 from pyspark.sql.column import Column, _to_seq, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter
+from pyspark.sql.types import *
 
 __all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"]
 
@@ -151,25 +152,6 @@ class DataFrame(object):
         """
         self._jdf.insertInto(tableName, overwrite)
 
-    def _java_save_mode(self, mode):
-        """Returns the Java save mode based on the Python save mode represented by a string.
-        """
-        jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode
-        jmode = jSaveMode.ErrorIfExists
-        mode = mode.lower()
-        if mode == "append":
-            jmode = jSaveMode.Append
-        elif mode == "overwrite":
-            jmode = jSaveMode.Overwrite
-        elif mode == "ignore":
-            jmode = jSaveMode.Ignore
-        elif mode == "error":
-            pass
-        else:
-            raise ValueError(
-                "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.")
-        return jmode
-
     def saveAsTable(self, tableName, source=None, mode="error", **options):
         """Saves the contents of this :class:`DataFrame` to a data source as a table.
 
@@ -185,11 +167,7 @@ class DataFrame(object):
         * `error`: Throw an exception if data already exists.
         * `ignore`: Silently ignore this operation if data already exists.
         """
-        if source is None:
-            source = self.sql_ctx.getConf("spark.sql.sources.default",
-                                          "org.apache.spark.sql.parquet")
-        jmode = self._java_save_mode(mode)
-        self._jdf.saveAsTable(tableName, source, jmode, options)
+        self.write.saveAsTable(tableName, source, mode, **options)
 
     def save(self, path=None, source=None, mode="error", **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
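`saveAsTable` likewise keeps its old signature and forwards to the writer; under a Hive-enabled context the two calls below are equivalent (table name and path are made up):

    df.saveAsTable("people_json", "json", "append", path="/tmp/people_tbl")
    df.write.saveAsTable("people_json", "json", "append", path="/tmp/people_tbl")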
@@ -206,13 +184,22 @@ class DataFrame(object):
         * `error`: Throw an exception if data already exists.
         * `ignore`: Silently ignore this operation if data already exists.
         """
-        if path is not None:
-            options["path"] = path
-        if source is None:
-            source = self.sql_ctx.getConf("spark.sql.sources.default",
-                                          "org.apache.spark.sql.parquet")
-        jmode = self._java_save_mode(mode)
-        self._jdf.save(source, jmode, options)
+        return self.write.save(path, source, mode, **options)
+
+    @property
+    def write(self):
+        """
+        Interface for saving the content of the :class:`DataFrame` out
+        into external storage.
+
+        :return :class:`DataFrameWriter`
+
+        ::note: Experimental
+
+        >>> df.write
+        <pyspark.sql.readwriter.DataFrameWriter object at ...>
+        """
+        return DataFrameWriter(self)
 
     @property
     def schema(self):
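A sketch of the new writer entry point (same `df` as in the doctests; target paths are hypothetical):

    writer = df.write                                    # a DataFrameWriter bound to this DataFrame
    writer.save("/tmp/people_default")                   # uses spark.sql.sources.default (parquet)
    writer.save("/tmp/people_json", format="json", mode="overwrite")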
@@ -411,9 +398,19 @@ class DataFrame(object):
         self._jdf.unpersist(blocking)
         return self
 
-    # def coalesce(self, numPartitions, shuffle=False):
-    #     rdd = self._jdf.coalesce(numPartitions, shuffle, None)
-    #     return DataFrame(rdd, self.sql_ctx)
+    def coalesce(self, numPartitions):
+        """
+        Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions.
+
+        Similar to coalesce defined on an :class:`RDD`, this operation results in a
+        narrow dependency, e.g. if you go from 1000 partitions to 100 partitions,
+        there will not be a shuffle, instead each of the 100 new partitions will
+        claim 10 of the current partitions.
+
+        >>> df.coalesce(1).rdd.getNumPartitions()
+        1
+        """
+        return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)
 
     def repartition(self, numPartitions):
         """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions.
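`coalesce` only narrows the partitioning (no shuffle), while the existing `repartition` does a full shuffle; a quick sketch with the doctest `df`:

    df4 = df.repartition(4)        # full shuffle into 4 partitions
    df1 = df4.coalesce(1)          # narrow dependency, no shuffle
    df1.rdd.getNumPartitions()     # 1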
python/pyspark/sql/readwriter.py (new file, 338 lines)

@@ -0,0 +1,338 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from py4j.java_gateway import JavaClass

from pyspark.sql.column import _to_seq
from pyspark.sql.types import *

__all__ = ["DataFrameReader", "DataFrameWriter"]

class DataFrameReader(object):
    """
    Interface used to load a :class:`DataFrame` from external storage systems
    (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read`
    to access this.

    ::Note: Experimental
    """

    def __init__(self, sqlContext):
        self._jreader = sqlContext._ssql_ctx.read()
        self._sqlContext = sqlContext

    def _df(self, jdf):
        from pyspark.sql.dataframe import DataFrame
        return DataFrame(jdf, self._sqlContext)

    def load(self, path=None, format=None, schema=None, **options):
        """Loads data from a data source and returns it as a :class`DataFrame`.

        :param path: optional string for file-system backed data sources.
        :param format: optional string for format of the data source. Default to 'parquet'.
        :param schema: optional :class:`StructType` for the input schema.
        :param options: all other string options
        """
        jreader = self._jreader
        if format is not None:
            jreader = jreader.format(format)
        if schema is not None:
            if not isinstance(schema, StructType):
                raise TypeError("schema should be StructType")
            jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
            jreader = jreader.schema(jschema)
        for k in options:
            jreader = jreader.option(k, options[k])
        if path is not None:
            return self._df(jreader.load(path))
        else:
            return self._df(jreader.load())

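A sketch of the generic load path (the directory and the `samplingRatio` option are illustrative assumptions; extra keyword arguments are passed to the data source as string options):

    events = sqlContext.read.load("/tmp/events", format="json", samplingRatio="0.5")
    # with format omitted, spark.sql.sources.default (parquet) is used
    metrics = sqlContext.read.load("/tmp/metrics.parquet")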
    def json(self, path, schema=None):
        """
        Loads a JSON file (one object per line) and returns the result as
        a :class`DataFrame`.

        If the ``schema`` parameter is not specified, this function goes
        through the input once to determine the input schema.

        :param path: string, path to the JSON dataset.
        :param schema: an optional :class:`StructType` for the input schema.

        >>> import tempfile, shutil
        >>> jsonFile = tempfile.mkdtemp()
        >>> shutil.rmtree(jsonFile)
        >>> with open(jsonFile, 'w') as f:
        ...     f.writelines(jsonStrings)
        >>> df1 = sqlContext.read.json(jsonFile)
        >>> df1.printSchema()
        root
         |-- field1: long (nullable = true)
         |-- field2: string (nullable = true)
         |-- field3: struct (nullable = true)
         |    |-- field4: long (nullable = true)

        >>> from pyspark.sql.types import *
        >>> schema = StructType([
        ...     StructField("field2", StringType()),
        ...     StructField("field3",
        ...         StructType([StructField("field5", ArrayType(IntegerType()))]))])
        >>> df2 = sqlContext.read.json(jsonFile, schema)
        >>> df2.printSchema()
        root
         |-- field2: string (nullable = true)
         |-- field3: struct (nullable = true)
         |    |-- field5: array (nullable = true)
         |    |    |-- element: integer (containsNull = true)
        """
        if schema is None:
            jdf = self._jreader.json(path)
        else:
            jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
            jdf = self._jreader.schema(jschema).json(path)
        return self._df(jdf)

    def table(self, tableName):
        """Returns the specified table as a :class:`DataFrame`.

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        >>> df2 = sqlContext.read.table("table1")
        >>> sorted(df.collect()) == sorted(df2.collect())
        True
        """
        return self._df(self._jreader.table(tableName))

    def parquet(self, *path):
        """Loads a Parquet file, returning the result as a :class:`DataFrame`.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> df.saveAsParquetFile(parquetFile)
        >>> df2 = sqlContext.read.parquet(parquetFile)
        >>> sorted(df.collect()) == sorted(df2.collect())
        True
        """
        return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path)))

    def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None,
             predicates=None, properties={}):
        """
        Construct a :class:`DataFrame` representing the database table accessible
        via JDBC URL `url` named `table` and connection `properties`.

        The `column` parameter could be used to partition the table, then it will
        be retrieved in parallel based on the parameters passed to this function.

        The `predicates` parameter gives a list expressions suitable for inclusion
        in WHERE clauses; each one defines one partition of the :class:`DataFrame`.

        ::Note: Don't create too many partitions in parallel on a large cluster;
        otherwise Spark might crash your external database systems.

        :param url: a JDBC URL
        :param table: name of table
        :param column: the column used to partition
        :param lowerBound: the lower bound of partition column
        :param upperBound: the upper bound of the partition column
        :param numPartitions: the number of partitions
        :param predicates: a list of expressions
        :param properties: JDBC database connection arguments, a list of arbitrary string
                           tag/value. Normally at least a "user" and "password" property
                           should be included.
        :return: a DataFrame
        """
        jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
        for k in properties:
            jprop.setProperty(k, properties[k])
        if column is not None:
            if numPartitions is None:
                numPartitions = self._sqlContext._sc.defaultParallelism
            return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound),
                                               int(numPartitions), jprop))
        if predicates is not None:
            arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates)
            return self._df(self._jreader.jdbc(url, table, arr, jprop))
        return self._df(self._jreader.jdbc(url, table, jprop))

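A sketch of the two partitioned-read paths (the JDBC URL, table, and column names are made up; the matching JDBC driver has to be on the classpath):

    props = {"user": "sa", "password": ""}
    # range-partition a numeric column into 4 parallel reads
    people = sqlContext.read.jdbc("jdbc:postgresql://db/test", "people",
                                  column="id", lowerBound=1, upperBound=1000,
                                  numPartitions=4, properties=props)
    # or give one WHERE-clause predicate per partition
    people = sqlContext.read.jdbc("jdbc:postgresql://db/test", "people",
                                  predicates=["id < 500", "id >= 500"],
                                  properties=props)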
class DataFrameWriter(object):
    """
    Interface used to write a [[DataFrame]] to external storage systems
    (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write`
    to access this.

    ::Note: Experimental
    """
    def __init__(self, df):
        self._df = df
        self._sqlContext = df.sql_ctx
        self._jwrite = df._jdf.write()

    def save(self, path=None, format=None, mode="error", **options):
        """
        Saves the contents of the :class:`DataFrame` to a data source.

        The data source is specified by the ``format`` and a set of ``options``.
        If ``format`` is not specified, the default data source configured by
        ``spark.sql.sources.default`` will be used.

        Additionally, mode is used to specify the behavior of the save operation when
        data already exists in the data source. There are four modes:

        * `append`: Append contents of this :class:`DataFrame` to existing data.
        * `overwrite`: Overwrite existing data.
        * `error`: Throw an exception if data already exists.
        * `ignore`: Silently ignore this operation if data already exists.

        :param path: the path in a Hadoop supported file system
        :param format: the format used to save
        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
        :param options: all other string options
        """
        jwrite = self._jwrite.mode(mode)
        if format is not None:
            jwrite = jwrite.format(format)
        for k in options:
            jwrite = jwrite.option(k, options[k])
        if path is None:
            jwrite.save()
        else:
            jwrite.save(path)

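The four modes in practice, writing the same `df` to a made-up path:

    df.write.save("/tmp/people", format="json")                    # mode="error": a second call raises
    df.write.save("/tmp/people", format="json", mode="ignore")     # a second call is a no-op
    df.write.save("/tmp/people", format="json", mode="overwrite")  # replaces existing data
    df.write.save("/tmp/people", format="json", mode="append")     # adds to existing data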
    def saveAsTable(self, name, format=None, mode="error", **options):
        """
        Saves the contents of this :class:`DataFrame` to a data source as a table.

        The data source is specified by the ``source`` and a set of ``options``.
        If ``source`` is not specified, the default data source configured by
        ``spark.sql.sources.default`` will be used.

        Additionally, mode is used to specify the behavior of the saveAsTable operation when
        table already exists in the data source. There are four modes:

        * `append`: Append contents of this :class:`DataFrame` to existing data.
        * `overwrite`: Overwrite existing data.
        * `error`: Throw an exception if data already exists.
        * `ignore`: Silently ignore this operation if data already exists.

        :param name: the table name
        :param format: the format used to save
        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
        :param options: all other string options
        """
        jwrite = self._jwrite.mode(mode)
        if format is not None:
            jwrite = jwrite.format(format)
        for k in options:
            jwrite = jwrite.option(k, options[k])
        return jwrite.saveAsTable(name)

    def json(self, path, mode="error"):
        """
        Saves the content of the :class:`DataFrame` in JSON format at the
        specified path.

        Additionally, mode is used to specify the behavior of the save operation when
        data already exists in the data source. There are four modes:

        * `append`: Append contents of this :class:`DataFrame` to existing data.
        * `overwrite`: Overwrite existing data.
        * `error`: Throw an exception if data already exists.
        * `ignore`: Silently ignore this operation if data already exists.

        :param path: the path in any Hadoop supported file system
        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
        """
        return self._jwrite.mode(mode).json(path)

    def parquet(self, path, mode="error"):
        """
        Saves the content of the :class:`DataFrame` in Parquet format at the
        specified path.

        Additionally, mode is used to specify the behavior of the save operation when
        data already exists in the data source. There are four modes:

        * `append`: Append contents of this :class:`DataFrame` to existing data.
        * `overwrite`: Overwrite existing data.
        * `error`: Throw an exception if data already exists.
        * `ignore`: Silently ignore this operation if data already exists.

        :param path: the path in any Hadoop supported file system
        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
        """
        return self._jwrite.mode(mode).parquet(path)

    def jdbc(self, url, table, mode="error", properties={}):
        """
        Saves the content of the :class:`DataFrame` to a external database table
        via JDBC.

        In the case the table already exists in the external database,
        behavior of this function depends on the save mode, specified by the `mode`
        function (default to throwing an exception). There are four modes:

        * `append`: Append contents of this :class:`DataFrame` to existing data.
        * `overwrite`: Overwrite existing data.
        * `error`: Throw an exception if data already exists.
        * `ignore`: Silently ignore this operation if data already exists.

        :param url: a JDBC URL of the form `jdbc:subprotocol:subname`
        :param table: Name of the table in the external database.
        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
        :param properties: JDBC database connection arguments, a list of
                           arbitrary string tag/value. Normally at least a
                           "user" and "password" property should be included.
        """
        jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
        for k in properties:
            jprop.setProperty(k, properties[k])
        self._jwrite.mode(mode).jdbc(url, table, jprop)

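The matching write-side JDBC sketch (URL and table name are made up; `properties` uses the same dict form as the reader):

    props = {"user": "sa", "password": ""}
    df.write.jdbc("jdbc:postgresql://db/test", "people_copy",
                  mode="overwrite", properties=props)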
def _test():
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import Row, SQLContext
    import pyspark.sql.readwriter
    globs = pyspark.sql.readwriter.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    globs['sqlContext'] = SQLContext(sc)
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))
    jsonStrings = [
        '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
        '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
        '"field6":[{"field7": "row2"}]}',
        '{"field1" : null, "field2": "row3", '
        '"field3":{"field4":33, "field5": []}}'
    ]
    globs['jsonStrings'] = jsonStrings
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.readwriter, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()
python/pyspark/sql/tests.py

@@ -485,29 +485,29 @@ class SQLTests(ReusedPySparkTestCase):
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
-        df.save(tmpPath, "org.apache.spark.sql.json", "error")
-        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.json(tmpPath)
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
-        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+        actual = self.sqlCtx.read.json(tmpPath, schema)
+        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
 
-        df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
-        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.json(tmpPath, "overwrite")
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
-                noUse="this options will not be used in save.")
-        actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
-                                  noUse="this options will not be used in load.")
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.save(format="json", mode="overwrite", path=tmpPath,
+                      noUse="this options will not be used in save.")
+        actual = self.sqlCtx.read.load(format="json", path=tmpPath,
+                                       noUse="this options will not be used in load.")
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                     "org.apache.spark.sql.parquet")
         self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
         actual = self.sqlCtx.load(path=tmpPath)
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
         self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         shutil.rmtree(tmpPath)
@@ -767,51 +767,44 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
-        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
-                                                 "org.apache.spark.sql.json")
-        self.assertTrue(
-            sorted(df.collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
-        self.assertTrue(
-            sorted(df.collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
+        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json")
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
         self.sqlCtx.sql("DROP TABLE externalJsonTable")
 
-        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
+        df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.createExternalTable("externalJsonTable",
-                                                 source="org.apache.spark.sql.json",
+        actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json",
                                                  schema=schema, path=tmpPath,
                                                  noUse="this options will not be used")
-        self.assertTrue(
-            sorted(df.collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
-        self.assertTrue(
-            sorted(df.select("value").collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
-        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertEqual(sorted(df.select("value").collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
         self.sqlCtx.sql("DROP TABLE savedJsonTable")
         self.sqlCtx.sql("DROP TABLE externalJsonTable")
 
         defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                     "org.apache.spark.sql.parquet")
         self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+        df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
         actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
-        self.assertTrue(
-            sorted(df.collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
-        self.assertTrue(
-            sorted(df.collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
         self.sqlCtx.sql("DROP TABLE savedJsonTable")
         self.sqlCtx.sql("DROP TABLE externalJsonTable")
         self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         shutil.rmtree(tmpPath)
 
 
 if __name__ == "__main__":
     unittest.main()