[SPARK-15982][SPARK-16009][SPARK-16007][SQL] Harmonize the behavior of DataFrameReader.text/csv/json/parquet/orc

## What changes were proposed in this pull request?

Issues with the current reader behavior (illustrated in the sketch after this list):
- `text()` without args returns an empty DF with no columns. This is inconsistent: `text()` is expected to always return a DF with a `value` string field.
- `textFile()` without args fails with an exception for the same reason: it expects the DF returned by `text()` to have a `value` field.
- `orc()` does not take varargs, which is inconsistent with the other formats.
- The single-argument `json(path)` was removed, which caused source compatibility issues - [SPARK-16009](https://issues.apache.org/jira/browse/SPARK-16009).
- A user-specified schema was not respected when `text/csv/...` were called with no args - [SPARK-16007](https://issues.apache.org/jira/browse/SPARK-16007).
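A minimal sketch of these inconsistencies as seen from user code (the SparkSession setup and the schema are illustrative, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StringType, StructType}

val spark = SparkSession.builder().master("local[*]").appName("reader-issues").getOrCreate()
val userSchema = new StructType().add("s", StringType)

spark.read.text()                          // empty DF with no columns instead of a single `value` column
spark.read.textFile()                      // fails, because it expects text() to produce a `value` column
// spark.read.orc("a.orc", "b.orc")        // does not compile: orc() has no varargs overload
// Option("a.json").map(spark.read.json)   // does not compile once the single-arg json(path) overload is gone
spark.read.schema(userSchema).csv()        // user-specified schema silently ignored, empty DF with no columns
```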

The solution I am implementing is the following (see the sketch after this list).
- For each format, provide both a single-argument method and a varargs method. For json, parquet, csv, and text, this means adding `json(path: String)`, etc.; for orc, it means adding `orc(paths: String*)`.
- Remove the special handling in `text()`, `csv()`, etc. that returned an empty DataFrame with no fields. Instead, pass the empty sequence of paths on to the data source and let each source handle it correctly. For example, the text data source should return an empty DF with schema `(value: string)`.
- Deduped docs and fixed their formatting.
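A matching sketch of the intended behavior after this change, reusing `spark` and `userSchema` from the sketch above (`orc()` additionally requires Hive support to actually run):

```scala
spark.read.text()                         // empty DF, but with the schema (value: string)
spark.read.textFile()                     // empty Dataset[String], no longer fails
spark.read.schema(userSchema).csv()       // empty DF that carries the user-specified schema (SPARK-16007)
spark.read.orc("a.orc", "b.orc")          // orc() now accepts varargs like the other formats
Option("a.json").map(spark.read.json)     // single-arg json(path) restored, so this compiles again (SPARK-16009)
```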

## How was this patch tested?
Added new unit tests for both the Scala and Java APIs; a condensed sketch of the test pattern follows.
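A condensed sketch of the round-trip pattern the new suites use, assuming a local SparkSession; the fixture names mirror those in the Scala suite in the diff below:

```scala
import java.nio.file.Files
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.types.{StringType, StructType}

val spark = SparkSession.builder().master("local[*]").appName("reader-tests").getOrCreate()
import spark.implicits._

val data = Seq("1", "2", "3")
val userSchema = new StructType().add("s", StringType)
val dir = Files.createTempDirectory("input").toString

// Write the fixture data, then read it back through the text source.
spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)

// Without a user-specified schema, the text source reports (value: string) ...
assert(spark.read.text(dir).schema == new StructType().add("value", StringType))
// ... and a user-specified schema is respected even when no paths are passed (SPARK-16007).
assert(spark.read.schema(userSchema).text().schema == userSchema)
// Contents round-trip as well.
assert(spark.read.textFile(dir).collect().toSeq.sorted == data.sorted)
```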

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #13727 from tdas/SPARK-15982.
Tathagata Das 2016-06-20 14:52:28 -07:00 committed by Shixiong Zhu
parent 6df8e38860
commit b99129cc45
3 changed files with 420 additions and 56 deletions


@@ -119,13 +119,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def load(): DataFrame = {
val dataSource =
DataSource(
sparkSession,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation()))
load(Seq.empty: _*) // force invocation of `load(...varargs...)`
}
/**
@@ -135,7 +129,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def load(path: String): DataFrame = {
option("path", path).load()
load(Seq(path): _*) // force invocation of `load(...varargs...)`
}
/**
@@ -146,18 +140,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*/
@scala.annotation.varargs
def load(paths: String*): DataFrame = {
if (paths.isEmpty) {
sparkSession.emptyDataFrame
} else {
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap).resolveRelation())
}
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap).resolveRelation())
}
/**
* Construct a [[DataFrame]] representing the database table accessible via JDBC URL
* url named table and connection properties.
@@ -245,6 +236,17 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
sparkSession.baseRelationToDataFrame(relation)
}
/**
* Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
* See the documentation on the overloaded `json()` method with varargs for more details.
*
* @since 1.4.0
*/
def json(path: String): DataFrame = {
// This method ensures that calls that explicitly need the single-argument variant work; see SPARK-16009
json(Seq(path): _*)
}
/**
* Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
*
@@ -252,6 +254,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* schema in advance, use the version that specifies the schema to avoid the extra scan.
*
* You can set the following JSON-specific options to deal with non-standard JSON files:
* <ul>
* <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
* <li>`prefersDecimal` (default `false`): infers all floating-point values as a decimal
* type. If the values do not fit in decimal, then it infers them as doubles.</li>
@@ -266,17 +269,17 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.</li>
* <ul>
* <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the
* malformed string into a new field configured by `columnNameOfCorruptRecord`. When
* <li> - `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
* the malformed string into a new field configured by `columnNameOfCorruptRecord`. When
* a schema is set by user, it sets `null` for extra fields.</li>
* <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
* <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
* <li> - `DROPMALFORMED` : ignores the whole corrupted records.</li>
* <li> - `FAILFAST` : throws an exception when it meets corrupted records.</li>
* </ul>
* <li>`columnNameOfCorruptRecord` (default is the value specified in
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
*
* @since 1.6.0
* </ul>
* @since 2.0.0
*/
@scala.annotation.varargs
def json(paths: String*): DataFrame = format("json").load(paths : _*)
@@ -326,6 +329,17 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
parsedOptions))(sparkSession))
}
/**
* Loads a CSV file and returns the result as a [[DataFrame]]. See the documentation on the
* other overloaded `csv()` method for more details.
*
* @since 2.0.0
*/
def csv(path: String): DataFrame = {
// This method ensures that calls that explicitly need the single-argument variant work; see SPARK-16009
csv(Seq(path): _*)
}
/**
* Loads a CSV file and returns the result as a [[DataFrame]].
*
@@ -334,6 +348,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* specify the schema explicitly using [[schema]].
*
* You can set the following CSV-specific options to deal with CSV files:
* <ul>
* <li>`sep` (default `,`): sets the single character as a separator for each
* field and value.</li>
* <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding
@@ -370,26 +385,37 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.</li>
* <ul>
* <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
* <li> - `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
* a schema is set by user, it sets `null` for extra fields.</li>
* <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
* <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
* <li> - `DROPMALFORMED` : ignores the whole corrupted records.</li>
* <li> - `FAILFAST` : throws an exception when it meets corrupted records.</li>
* </ul>
* </ul>
*
* @since 2.0.0
*/
@scala.annotation.varargs
def csv(paths: String*): DataFrame = format("csv").load(paths : _*)
/**
* Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty
* [[DataFrame]] if no paths are passed in.
* Loads a Parquet file, returning the result as a [[DataFrame]]. See the documentation
* on the other overloaded `parquet()` method for more details.
*
* @since 2.0.0
*/
def parquet(path: String): DataFrame = {
// This method ensures that calls that explicitly need the single-argument variant work; see SPARK-16009
parquet(Seq(path): _*)
}
/**
* Loads a Parquet file, returning the result as a [[DataFrame]].
*
* You can set the following Parquet-specific option(s) for reading Parquet files:
* <ul>
* <li>`mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets
* whether we should merge schemas collected from all Parquet part-files. This will override
* `spark.sql.parquet.mergeSchema`.</li>
*
* </ul>
* @since 1.4.0
*/
@scala.annotation.varargs
@@ -404,7 +430,20 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.5.0
* @note Currently, this method can only be used after enabling Hive support.
*/
def orc(path: String): DataFrame = format("orc").load(path)
def orc(path: String): DataFrame = {
// This method ensures that calls that explicitly need the single-argument variant work; see SPARK-16009
orc(Seq(path): _*)
}
/**
* Loads an ORC file and returns the result as a [[DataFrame]].
*
* @param paths input paths
* @since 2.0.0
* @note Currently, this method can only be used after enabling Hive support.
*/
@scala.annotation.varargs
def orc(paths: String*): DataFrame = format("orc").load(paths: _*)
/**
* Returns the specified table as a [[DataFrame]].
@@ -417,6 +456,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)))
}
/**
* Loads text files and returns a [[DataFrame]] whose schema starts with a string column named
* "value", and followed by partitioned columns if there are any. See the documentation on
* the other overloaded `text()` method for more details.
*
* @since 2.0.0
*/
def text(path: String): DataFrame = {
// This method ensures that calls that explicitly need the single-argument variant work; see SPARK-16009
text(Seq(path): _*)
}
/**
* Loads text files and returns a [[DataFrame]] whose schema starts with a string column named
* "value", and followed by partitioned columns if there are any.
@@ -430,12 +481,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* spark.read().text("/path/to/spark/README.md")
* }}}
*
* @param paths input path
* @param paths input paths
* @since 1.6.0
*/
@scala.annotation.varargs
def text(paths: String*): DataFrame = format("text").load(paths : _*)
/**
* Loads text files and returns a [[Dataset]] of String. See the documentation on the
* other overloaded `textFile()` method for more details.
* @since 2.0.0
*/
def textFile(path: String): Dataset[String] = {
// This method ensures that calls that explicitly need the single-argument variant work; see SPARK-16009
textFile(Seq(path): _*)
}
/**
* Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset
* contains a single string column named "value".
@@ -457,6 +518,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*/
@scala.annotation.varargs
def textFile(paths: String*): Dataset[String] = {
if (userSpecifiedSchema.nonEmpty) {
throw new AnalysisException("User specified schema not supported with `textFile`")
}
text(paths : _*).select("value").as[String](sparkSession.implicits.newStringEncoder)
}
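The option lists documented in the scaladoc above can be exercised roughly as follows; the paths, schema, and option values here are placeholders rather than anything prescribed by this patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StringType, StructType}

val spark = SparkSession.builder().master("local[*]").appName("reader-options").getOrCreate()
val schema = new StructType().add("s", StringType).add("_corrupt", StringType)

// JSON: keep malformed records in a dedicated column instead of failing the whole read.
val jsonDf = spark.read
  .schema(schema)
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_corrupt")
  .json("events.json")

// CSV: non-default separator, and silently drop records that do not parse.
val csvDf = spark.read
  .option("sep", ";")
  .option("mode", "DROPMALFORMED")
  .csv("part1.csv", "part2.csv")

// textFile() rejects a user-specified schema, since its schema is fixed to a single string column.
// spark.read.schema(schema).textFile("notes.txt")   // throws AnalysisException
```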


@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test.org.apache.spark.sql;
import java.io.File;
import java.util.HashMap;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.Utils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
public class JavaDataFrameReaderWriterSuite {
private SparkSession spark = new TestSparkSession();
private StructType schema = new StructType().add("s", "string");
private transient String input;
private transient String output;
@Before
public void setUp() {
input = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "input").toString();
File f = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "output");
f.delete();
output = f.toString();
}
@After
public void tearDown() {
spark.stop();
spark = null;
}
@Test
public void testFormatAPI() {
spark
.read()
.format("org.apache.spark.sql.test")
.load()
.write()
.format("org.apache.spark.sql.test")
.save();
}
@Test
public void testOptionsAPI() {
HashMap<String, String> map = new HashMap<String, String>();
map.put("e", "1");
spark
.read()
.option("a", "1")
.option("b", 1)
.option("c", 1.0)
.option("d", true)
.options(map)
.text()
.write()
.option("a", "1")
.option("b", 1)
.option("c", 1.0)
.option("d", true)
.options(map)
.format("org.apache.spark.sql.test")
.save();
}
@Test
public void testSaveModeAPI() {
spark
.range(10)
.write()
.format("org.apache.spark.sql.test")
.mode(SaveMode.ErrorIfExists)
.save();
}
@Test
public void testLoadAPI() {
spark.read().format("org.apache.spark.sql.test").load();
spark.read().format("org.apache.spark.sql.test").load(input);
spark.read().format("org.apache.spark.sql.test").load(input, input, input);
spark.read().format("org.apache.spark.sql.test").load(new String[]{input, input});
}
@Test
public void testTextAPI() {
spark.read().text();
spark.read().text(input);
spark.read().text(input, input, input);
spark.read().text(new String[]{input, input})
.write().text(output);
}
@Test
public void testTextFileAPI() {
spark.read().textFile();
spark.read().textFile(input);
spark.read().textFile(input, input, input);
spark.read().textFile(new String[]{input, input});
}
@Test
public void testCsvAPI() {
spark.read().schema(schema).csv();
spark.read().schema(schema).csv(input);
spark.read().schema(schema).csv(input, input, input);
spark.read().schema(schema).csv(new String[]{input, input})
.write().csv(output);
}
@Test
public void testJsonAPI() {
spark.read().schema(schema).json();
spark.read().schema(schema).json(input);
spark.read().schema(schema).json(input, input, input);
spark.read().schema(schema).json(new String[]{input, input})
.write().json(output);
}
@Test
public void testParquetAPI() {
spark.read().schema(schema).parquet();
spark.read().schema(schema).parquet(input);
spark.read().schema(schema).parquet(input, input, input);
spark.read().schema(schema).parquet(new String[] { input, input })
.write().parquet(output);
}
/**
* This only tests whether the API compiles; it is not run because orc()
* cannot be used without Hive classes.
*/
public void testOrcAPI() {
spark.read().schema(schema).orc();
spark.read().schema(schema).orc(input);
spark.read().schema(schema).orc(input, input, input);
spark.read().schema(schema).orc(new String[]{input, input})
.write().orc(output);
}
}


@@ -17,6 +17,10 @@
package org.apache.spark.sql.test
import java.io.File
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
@@ -79,10 +83,19 @@ class DefaultSource
}
class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
private def newMetadataDir =
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
private val userSchema = new StructType().add("s", StringType)
private val textSchema = new StructType().add("value", StringType)
private val data = Seq("1", "2", "3")
private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath
private implicit var enc: Encoder[String] = _
before {
enc = spark.implicits.newStringEncoder
Utils.deleteRecursively(new File(dir))
}
test("writeStream cannot be called on non-streaming datasets") {
val e = intercept[AnalysisException] {
@@ -157,24 +170,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
}
test("paths") {
val df = spark.read
.format("org.apache.spark.sql.test")
.option("checkpointLocation", newMetadataDir)
.load("/test")
assert(LastOptions.parameters("path") == "/test")
LastOptions.clear()
df.write
.format("org.apache.spark.sql.test")
.option("checkpointLocation", newMetadataDir)
.save("/test")
assert(LastOptions.parameters("path") == "/test")
}
test("test different data types for options") {
val df = spark.read
.format("org.apache.spark.sql.test")
@@ -193,7 +188,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
.option("intOpt", 56)
.option("boolOpt", false)
.option("doubleOpt", 6.7)
.option("checkpointLocation", newMetadataDir)
.save("/test")
assert(LastOptions.parameters("intOpt") == "56")
@@ -228,4 +222,152 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
}
}
}
test("load API") {
spark.read.format("org.apache.spark.sql.test").load()
spark.read.format("org.apache.spark.sql.test").load(dir)
spark.read.format("org.apache.spark.sql.test").load(dir, dir, dir)
spark.read.format("org.apache.spark.sql.test").load(Seq(dir, dir): _*)
Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
}
test("text - API and behavior regarding schema") {
// Writer
spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
testRead(spark.read.text(dir), data, textSchema)
// Reader, without user specified schema
testRead(spark.read.text(), Seq.empty, textSchema)
testRead(spark.read.text(dir, dir, dir), data ++ data ++ data, textSchema)
testRead(spark.read.text(Seq(dir, dir): _*), data ++ data, textSchema)
// Test explicit calls to single arg method - SPARK-16009
testRead(Option(dir).map(spark.read.text).get, data, textSchema)
// Reader, with user specified schema, should just apply user schema on the file data
testRead(spark.read.schema(userSchema).text(), Seq.empty, userSchema)
testRead(spark.read.schema(userSchema).text(dir), data, userSchema)
testRead(spark.read.schema(userSchema).text(dir, dir), data ++ data, userSchema)
testRead(spark.read.schema(userSchema).text(Seq(dir, dir): _*), data ++ data, userSchema)
}
test("textFile - API and behavior regarding schema") {
spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
// Reader, without user specified schema
testRead(spark.read.textFile().toDF(), Seq.empty, textSchema)
testRead(spark.read.textFile(dir).toDF(), data, textSchema)
testRead(spark.read.textFile(dir, dir).toDF(), data ++ data, textSchema)
testRead(spark.read.textFile(Seq(dir, dir): _*).toDF(), data ++ data, textSchema)
// Test explicit calls to single arg method - SPARK-16009
testRead(Option(dir).map(spark.read.text).get, data, textSchema)
// Reader, with user specified schema, should just apply user schema on the file data
val e = intercept[AnalysisException] { spark.read.schema(userSchema).textFile() }
assert(e.getMessage.toLowerCase.contains("user specified schema not supported"))
intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir) }
intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir, dir) }
intercept[AnalysisException] { spark.read.schema(userSchema).textFile(Seq(dir, dir): _*) }
}
test("csv - API and behavior regarding schema") {
// Writer
spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).csv(dir)
val df = spark.read.csv(dir)
checkAnswer(df, spark.createDataset(data).toDF())
val schema = df.schema
// Reader, without user specified schema
intercept[IllegalArgumentException] {
testRead(spark.read.csv(), Seq.empty, schema)
}
testRead(spark.read.csv(dir), data, schema)
testRead(spark.read.csv(dir, dir), data ++ data, schema)
testRead(spark.read.csv(Seq(dir, dir): _*), data ++ data, schema)
// Test explicit calls to single arg method - SPARK-16009
testRead(Option(dir).map(spark.read.csv).get, data, schema)
// Reader, with user specified schema, should just apply user schema on the file data
testRead(spark.read.schema(userSchema).csv(), Seq.empty, userSchema)
testRead(spark.read.schema(userSchema).csv(dir), data, userSchema)
testRead(spark.read.schema(userSchema).csv(dir, dir), data ++ data, userSchema)
testRead(spark.read.schema(userSchema).csv(Seq(dir, dir): _*), data ++ data, userSchema)
}
test("json - API and behavior regarding schema") {
// Writer
spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).json(dir)
val df = spark.read.json(dir)
checkAnswer(df, spark.createDataset(data).toDF())
val schema = df.schema
// Reader, without user specified schema
intercept[AnalysisException] {
testRead(spark.read.json(), Seq.empty, schema)
}
testRead(spark.read.json(dir), data, schema)
testRead(spark.read.json(dir, dir), data ++ data, schema)
testRead(spark.read.json(Seq(dir, dir): _*), data ++ data, schema)
// Test explicit calls to single arg method - SPARK-16009
testRead(Option(dir).map(spark.read.json).get, data, schema)
// Reader, with user specified schema, data should be nulls as schema in file different
// from user schema
val expData = Seq[String](null, null, null)
testRead(spark.read.schema(userSchema).json(), Seq.empty, userSchema)
testRead(spark.read.schema(userSchema).json(dir), expData, userSchema)
testRead(spark.read.schema(userSchema).json(dir, dir), expData ++ expData, userSchema)
testRead(spark.read.schema(userSchema).json(Seq(dir, dir): _*), expData ++ expData, userSchema)
}
test("parquet - API and behavior regarding schema") {
// Writer
spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).parquet(dir)
val df = spark.read.parquet(dir)
checkAnswer(df, spark.createDataset(data).toDF())
val schema = df.schema
// Reader, without user specified schema
intercept[AnalysisException] {
testRead(spark.read.parquet(), Seq.empty, schema)
}
testRead(spark.read.parquet(dir), data, schema)
testRead(spark.read.parquet(dir, dir), data ++ data, schema)
testRead(spark.read.parquet(Seq(dir, dir): _*), data ++ data, schema)
// Test explicit calls to single arg method - SPARK-16009
testRead(Option(dir).map(spark.read.parquet).get, data, schema)
// Reader, with user specified schema, data should be nulls as schema in file different
// from user schema
val expData = Seq[String](null, null, null)
testRead(spark.read.schema(userSchema).parquet(), Seq.empty, userSchema)
testRead(spark.read.schema(userSchema).parquet(dir), expData, userSchema)
testRead(spark.read.schema(userSchema).parquet(dir, dir), expData ++ expData, userSchema)
testRead(
spark.read.schema(userSchema).parquet(Seq(dir, dir): _*), expData ++ expData, userSchema)
}
/**
* This only tests whether the API compiles; it is not run because orc()
* cannot be used without Hive classes.
*/
ignore("orc - API") {
// Reader, with user specified schema
// Refer to csv-specific test suites for behavior without user specified schema
spark.read.schema(userSchema).orc()
spark.read.schema(userSchema).orc(dir)
spark.read.schema(userSchema).orc(dir, dir, dir)
spark.read.schema(userSchema).orc(Seq(dir, dir): _*)
Option(dir).map(spark.read.schema(userSchema).orc)
// Writer
spark.range(10).write.orc(dir)
}
private def testRead(
df: => DataFrame,
expectedResult: Seq[String],
expectedSchema: StructType): Unit = {
checkAnswer(df, spark.createDataset(expectedResult).toDF())
assert(df.schema === expectedSchema)
}
}