# [SPARK-20590][SQL] Use Spark internal datasource if multiple sources are found for the same short name
## What changes were proposed in this pull request?

One of the common usability problems when reading data in Spark (particularly CSV) is that there can be a conflict between different readers on the classpath. As an example, if someone launches a 2.x spark-shell with the spark-csv package on the classpath, Spark currently fails in an extremely unfriendly way (see databricks/spark-csv#367):

```bash
./bin/spark-shell --packages com.databricks:spark-csv_2.11:1.5.0
```

```scala
scala> val df = spark.read.csv("/foo/bar.csv")
java.lang.RuntimeException: Multiple sources found for csv (org.apache.spark.sql.execution.datasources.csv.CSVFileFormat, com.databricks.spark.csv.DefaultSource15), please specify the fully qualified class name.
  at scala.sys.package$.error(package.scala:27)
  at org.apache.spark.sql.execution.datasources.DataSource$.lookupDataSource(DataSource.scala:574)
  at org.apache.spark.sql.execution.datasources.DataSource.providingClass$lzycompute(DataSource.scala:85)
  at org.apache.spark.sql.execution.datasources.DataSource.providingClass(DataSource.scala:85)
  at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:295)
  at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:178)
  at org.apache.spark.sql.DataFrameReader.csv(DataFrameReader.scala:533)
  at org.apache.spark.sql.DataFrameReader.csv(DataFrameReader.scala:412)
  ... 48 elided
```

This PR proposes a simple fix: when multiple sources match an alias but exactly one of them is internal to Spark (its class name starts with the "org.apache.spark" prefix), pick the internal one and log a warning instead of failing:

```scala
scala> spark.range(1).write.format("csv").mode("overwrite").save("/tmp/abc")
17/05/10 09:47:44 WARN DataSource: Multiple sources found for csv (org.apache.spark.sql.execution.datasources.csv.CSVFileFormat, com.databricks.spark.csv.DefaultSource15), defaulting to the internal datasource (org.apache.spark.sql.execution.datasources.csv.CSVFileFormat).
```

```scala
scala> spark.range(1).write.format("Csv").mode("overwrite").save("/tmp/abc")
17/05/10 09:47:52 WARN DataSource: Multiple sources found for Csv (org.apache.spark.sql.execution.datasources.csv.CSVFileFormat, com.databricks.spark.csv.DefaultSource15), defaulting to the internal datasource (org.apache.spark.sql.execution.datasources.csv.CSVFileFormat).
```

## How was this patch tested?

Manually tested as below:

```bash
./bin/spark-shell --packages com.databricks:spark-csv_2.11:1.5.0
```

```scala
spark.sparkContext.setLogLevel("WARN")
```

**Positive cases**:

```scala
scala> spark.range(1).write.format("csv").mode("overwrite").save("/tmp/abc")
17/05/10 09:47:44 WARN DataSource: Multiple sources found for csv (org.apache.spark.sql.execution.datasources.csv.CSVFileFormat, com.databricks.spark.csv.DefaultSource15), defaulting to the internal datasource (org.apache.spark.sql.execution.datasources.csv.CSVFileFormat).
```

```scala
scala> spark.range(1).write.format("Csv").mode("overwrite").save("/tmp/abc")
17/05/10 09:47:52 WARN DataSource: Multiple sources found for Csv (org.apache.spark.sql.execution.datasources.csv.CSVFileFormat, com.databricks.spark.csv.DefaultSource15), defaulting to the internal datasource (org.apache.spark.sql.execution.datasources.csv.CSVFileFormat).
```

(Newlines were inserted into the warning messages above for readability.)
```scala
scala> spark.range(1).write.format("com.databricks.spark.csv").mode("overwrite").save("/tmp/abc")
```

```scala
scala> spark.range(1).write.format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat").mode("overwrite").save("/tmp/abc")
```

**Negative cases**:

```scala
scala> spark.range(1).write.format("com.databricks.spark.csv.CsvRelation").save("/tmp/abc")
java.lang.InstantiationException: com.databricks.spark.csv.CsvRelation
...
```

```scala
scala> spark.range(1).write.format("com.databricks.spark.csv.CsvRelatio").save("/tmp/abc")
java.lang.ClassNotFoundException: Failed to find data source: com.databricks.spark.csv.CsvRelatio. Please find packages at http://spark.apache.org/third-party-projects.html
...
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #17916 from HyukjinKwon/datasource-detect.
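The core of the change is a single disambiguation step. As a condensed sketch of the rule, simplified here to operate on plain `Class[_]` values (the real change in `DataSource.lookupDataSource` appears in the diff below):

```scala
// Simplified sketch of the new rule, not Spark's exact code: prefer the one
// internal candidate when the alias is ambiguous, otherwise fail as before.
def disambiguate(provider: String, sources: Seq[Class[_]]): Class[_] = {
  val internal = sources.filter(_.getName.startsWith("org.apache.spark"))
  if (internal.size == 1) {
    // Exactly one built-in implementation matched: default to it (with a warning).
    internal.head
  } else {
    sys.error(s"Multiple sources found for $provider " +
      s"(${sources.map(_.getName).mkString(", ")}), " +
      "please specify the fully qualified class name.")
  }
}
```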
commit 3d2131ab4d (parent 771abeb46f)
**DataSource.scala**

```diff
@@ -481,7 +481,7 @@ case class DataSource(
   }
 }
 
-object DataSource {
+object DataSource extends Logging {
 
   /** A map to maintain backward compatibility in case we move data sources around. */
   private val backwardCompatibilityMap: Map[String, String] = {
@@ -570,10 +570,19 @@ object DataSource {
           // there is exactly one registered alias
           head.getClass
         case sources =>
-          // There are multiple registered aliases for the input
-          sys.error(s"Multiple sources found for $provider1 " +
-            s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
-            "please specify the fully qualified class name.")
+          // There are multiple registered aliases for the input. If there is single datasource
+          // that has "org.apache.spark" package in the prefix, we use it considering it is an
+          // internal datasource within Spark.
+          val sourceNames = sources.map(_.getClass.getName)
+          val internalSources = sources.filter(_.getClass.getName.startsWith("org.apache.spark"))
+          if (internalSources.size == 1) {
+            logWarning(s"Multiple sources found for $provider1 (${sourceNames.mkString(", ")}), " +
+              s"defaulting to the internal datasource (${internalSources.head.getClass.getName}).")
+            internalSources.head.getClass
+          } else {
+            throw new AnalysisException(s"Multiple sources found for $provider1 " +
+              s"(${sourceNames.mkString(", ")}), please specify the fully qualified class name.")
+          }
       }
     } catch {
       case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] =>
```
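For context on where `sources` comes from: Spark discovers `DataSourceRegister` implementations through `java.util.ServiceLoader` (hence the `ServiceConfigurationError` handler above) and matches the requested alias against each provider's `shortName()`, which is why two jars can legitimately collide on a name like `csv`. A minimal sketch of that discovery step, simplified from the real `lookupDataSource`, which also consults the backward-compatibility map shown above:

```scala
import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.DataSourceRegister

// Collect every provider on the classpath whose shortName() matches the alias.
// The match is case-insensitive, which is why "csv" and "Csv" hit the same conflict.
def providersFor(alias: String, loader: ClassLoader): List[DataSourceRegister] =
  ServiceLoader.load(classOf[DataSourceRegister], loader)
    .asScala
    .filter(_.shortName().equalsIgnoreCase(alias))
    .toList
```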
**META-INF/services/org.apache.spark.sql.sources.DataSourceRegister** (test resources)

```diff
@@ -1,3 +1,7 @@
 org.apache.spark.sql.sources.FakeSourceOne
 org.apache.spark.sql.sources.FakeSourceTwo
 org.apache.spark.sql.sources.FakeSourceThree
+org.apache.spark.sql.sources.FakeSourceFour
+org.apache.fakesource.FakeExternalSourceOne
+org.apache.fakesource.FakeExternalSourceTwo
+org.apache.fakesource.FakeExternalSourceThree
```
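With these registrations the suite gets one contested alias per branch of the new logic: "Fluet da Bomb" is shared by two of the internal fakes, "datasource" by one internal and one external source (`FakeSourceFour` and `FakeExternalSourceThree`), and "Fake external source" by two external ones (`FakeExternalSourceOne` and `FakeExternalSourceTwo`).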
**DDLSourceLoadSuite.scala**

```diff
@@ -19,26 +19,39 @@ package org.apache.spark.sql.sources
 
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{AnalysisException, SQLContext}
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 
 
 // please note that the META-INF/services had to be modified for the test directory for this to work
 class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {
 
-  test("data sources with the same name") {
-    intercept[RuntimeException] {
+  test("data sources with the same name - internal data sources") {
+    val e = intercept[AnalysisException] {
       spark.read.format("Fluet da Bomb").load()
     }
+    assert(e.getMessage.contains("Multiple sources found for Fluet da Bomb"))
   }
 
+  test("data sources with the same name - internal data source/external data source") {
+    assert(spark.read.format("datasource").load().schema ==
+      StructType(Seq(StructField("longType", LongType, nullable = false))))
+  }
+
+  test("data sources with the same name - external data sources") {
+    val e = intercept[AnalysisException] {
+      spark.read.format("Fake external source").load()
+    }
+    assert(e.getMessage.contains("Multiple sources found for Fake external source"))
+  }
+
   test("load data source from format alias") {
-    spark.read.format("gathering quorum").load().schema ==
-      StructType(Seq(StructField("stringType", StringType, nullable = false)))
+    assert(spark.read.format("gathering quorum").load().schema ==
+      StructType(Seq(StructField("stringType", StringType, nullable = false))))
   }
 
   test("specify full classname with duplicate formats") {
-    spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
-      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
+    assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
+      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))))
   }
 
   test("should fail to load ORC without Hive Support") {
@@ -63,7 +76,7 @@ class FakeSourceOne extends RelationProvider with DataSourceRegister {
     }
 }
 
-class FakeSourceTwo extends RelationProvider  with DataSourceRegister {
+class FakeSourceTwo extends RelationProvider with DataSourceRegister {
 
   def shortName(): String = "Fluet da Bomb"
 
@@ -72,7 +85,7 @@ class FakeSourceTwo extends RelationProvider with DataSourceRegister {
       override def sqlContext: SQLContext = cont
 
       override def schema: StructType =
-        StructType(Seq(StructField("stringType", StringType, nullable = false)))
+        StructType(Seq(StructField("integerType", IntegerType, nullable = false)))
     }
 }
 
@@ -88,3 +101,16 @@ class FakeSourceThree extends RelationProvider with DataSourceRegister {
         StructType(Seq(StructField("stringType", StringType, nullable = false)))
     }
 }
+
+class FakeSourceFour extends RelationProvider with DataSourceRegister {
+
+  def shortName(): String = "datasource"
+
+  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+    new BaseRelation {
+      override def sqlContext: SQLContext = cont
+
+      override def schema: StructType =
+        StructType(Seq(StructField("longType", LongType, nullable = false)))
+    }
+}
```
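As the "specify full classname with duplicate formats" test underlines, spelling out the fully qualified class name always sidesteps alias resolution, and it remains the escape hatch when no single internal candidate exists. For example, reusing the manual-test session from the description:

```scala
// Naming the implementation class directly bypasses the alias lookup,
// so no "Multiple sources found" warning or error is possible.
spark.range(1)
  .write
  .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat")
  .mode("overwrite")
  .save("/tmp/abc")
```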
**New test file** (package `org.apache.fakesource`)

```diff
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.fakesource
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
+import org.apache.spark.sql.types._
+
+
+// Note that the package name is intendedly mismatched in order to resemble external data sources
+// and test the detection for them.
+class FakeExternalSourceOne extends RelationProvider with DataSourceRegister {
+
+  def shortName(): String = "Fake external source"
+
+  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+    new BaseRelation {
+      override def sqlContext: SQLContext = cont
+
+      override def schema: StructType =
+        StructType(Seq(StructField("stringType", StringType, nullable = false)))
+    }
+}
+
+class FakeExternalSourceTwo extends RelationProvider with DataSourceRegister {
+
+  def shortName(): String = "Fake external source"
+
+  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+    new BaseRelation {
+      override def sqlContext: SQLContext = cont
+
+      override def schema: StructType =
+        StructType(Seq(StructField("integerType", IntegerType, nullable = false)))
+    }
+}
+
+class FakeExternalSourceThree extends RelationProvider with DataSourceRegister {
+
+  def shortName(): String = "datasource"
+
+  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+    new BaseRelation {
+      override def sqlContext: SQLContext = cont
+
+      override def schema: StructType =
+        StructType(Seq(StructField("byteType", ByteType, nullable = false)))
+    }
+}
```
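Taken together with `FakeSourceFour` in the suite above, the contested `datasource` alias now resolves silently in favor of the internal fake (its class name starts with `org.apache.spark`), which is exactly what the new test asserts:

```scala
// From DDLSourceLoadSuite: FakeSourceFour (internal) wins over
// FakeExternalSourceThree (external), so the longType schema is returned.
assert(spark.read.format("datasource").load().schema ==
  StructType(Seq(StructField("longType", LongType, nullable = false))))
```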