[SPARK-5873][SQL] Allow viewing of partially analyzed plans in queryExecution
Author: Michael Armbrust <michael@databricks.com> Closes #4684 from marmbrus/explainAnalysis and squashes the following commits: afbaa19 [Michael Armbrust] fix python d93278c [Michael Armbrust] fix hive e5fa0a4 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explainAnalysis 52119f2 [Michael Armbrust] more tests 82a5431 [Michael Armbrust] fix tests 25753d2 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explainAnalysis aee1e6a [Michael Armbrust] fix hive b23a844 [Michael Armbrust] newline de8dc51 [Michael Armbrust] more comments acf620a [Michael Armbrust] [SPARK-5873][SQL] Show partially analyzed plans in query execution
This commit is contained in:
parent
48376bfe9c
commit
1ed57086d4
|
@ -267,20 +267,20 @@ class SQLContext(object):
|
|||
... StructField("byte2", ByteType(), False),
|
||||
... StructField("short1", ShortType(), False),
|
||||
... StructField("short2", ShortType(), False),
|
||||
... StructField("int", IntegerType(), False),
|
||||
... StructField("float", FloatType(), False),
|
||||
... StructField("date", DateType(), False),
|
||||
... StructField("time", TimestampType(), False),
|
||||
... StructField("map",
|
||||
... StructField("int1", IntegerType(), False),
|
||||
... StructField("float1", FloatType(), False),
|
||||
... StructField("date1", DateType(), False),
|
||||
... StructField("time1", TimestampType(), False),
|
||||
... StructField("map1",
|
||||
... MapType(StringType(), IntegerType(), False), False),
|
||||
... StructField("struct",
|
||||
... StructField("struct1",
|
||||
... StructType([StructField("b", ShortType(), False)]), False),
|
||||
... StructField("list", ArrayType(ByteType(), False), False),
|
||||
... StructField("null", DoubleType(), True)])
|
||||
... StructField("list1", ArrayType(ByteType(), False), False),
|
||||
... StructField("null1", DoubleType(), True)])
|
||||
>>> df = sqlCtx.applySchema(rdd, schema)
|
||||
>>> results = df.map(
|
||||
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
|
||||
... x.time, x.map["a"], x.struct.b, x.list, x.null))
|
||||
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
|
||||
... x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
|
||||
>>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
|
||||
(127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
|
||||
datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
|
||||
|
@ -288,20 +288,20 @@ class SQLContext(object):
|
|||
>>> df.registerTempTable("table2")
|
||||
>>> sqlCtx.sql(
|
||||
... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
|
||||
... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
|
||||
... "float + 1.5 as float FROM table2").collect()
|
||||
[Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)]
|
||||
... "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
|
||||
... "float1 + 1.5 as float1 FROM table2").collect()
|
||||
[Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int1=2147483646, float1=2.5)]
|
||||
|
||||
>>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
|
||||
>>> rdd = sc.parallelize([(127, -32768, 1.0,
|
||||
... datetime(2010, 1, 1, 1, 1, 1),
|
||||
... {"a": 1}, (2,), [1, 2, 3])])
|
||||
>>> abstract = "byte short float time map{} struct(b) list[]"
|
||||
>>> abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
|
||||
>>> schema = _parse_schema_abstract(abstract)
|
||||
>>> typedSchema = _infer_schema_type(rdd.first(), schema)
|
||||
>>> df = sqlCtx.applySchema(rdd, typedSchema)
|
||||
>>> df.collect()
|
||||
[Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
|
||||
[Row(byte1=127, short1=-32768, float1=1.0, time1=..., list1=[1, 2, 3])]
|
||||
"""
|
||||
|
||||
if isinstance(rdd, DataFrame):
|
||||
|
|
|
@ -78,6 +78,7 @@ class SqlParser extends AbstractSparkSQLParser {
|
|||
protected val IF = Keyword("IF")
|
||||
protected val IN = Keyword("IN")
|
||||
protected val INNER = Keyword("INNER")
|
||||
protected val INT = Keyword("INT")
|
||||
protected val INSERT = Keyword("INSERT")
|
||||
protected val INTERSECT = Keyword("INTERSECT")
|
||||
protected val INTO = Keyword("INTO")
|
||||
|
@ -394,6 +395,7 @@ class SqlParser extends AbstractSparkSQLParser {
|
|||
| fixedDecimalType
|
||||
| DECIMAL ^^^ DecimalType.Unlimited
|
||||
| DATE ^^^ DateType
|
||||
| INT ^^^ IntegerType
|
||||
)
|
||||
|
||||
protected lazy val fixedDecimalType: Parser[DataType] =
|
||||
|
|
|
@ -52,12 +52,6 @@ class Analyzer(catalog: Catalog,
|
|||
*/
|
||||
val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
|
||||
|
||||
/**
|
||||
* Override to provide additional rules for the "Check Analysis" batch.
|
||||
* These rules will be evaluated after our built-in check rules.
|
||||
*/
|
||||
val extendedCheckRules: Seq[Rule[LogicalPlan]] = Nil
|
||||
|
||||
lazy val batches: Seq[Batch] = Seq(
|
||||
Batch("Resolution", fixedPoint,
|
||||
ResolveRelations ::
|
||||
|
@ -71,87 +65,10 @@ class Analyzer(catalog: Catalog,
|
|||
TrimGroupingAliases ::
|
||||
typeCoercionRules ++
|
||||
extendedResolutionRules : _*),
|
||||
Batch("Check Analysis", Once,
|
||||
CheckResolution +:
|
||||
extendedCheckRules: _*),
|
||||
Batch("Remove SubQueries", fixedPoint,
|
||||
EliminateSubQueries)
|
||||
)
|
||||
|
||||
/**
|
||||
* Makes sure all attributes and logical plans have been resolved.
|
||||
*/
|
||||
object CheckResolution extends Rule[LogicalPlan] {
|
||||
def failAnalysis(msg: String) = { throw new AnalysisException(msg) }
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = {
|
||||
// We transform up and order the rules so as to catch the first possible failure instead
|
||||
// of the result of cascading resolution failures.
|
||||
plan.foreachUp {
|
||||
case operator: LogicalPlan =>
|
||||
operator transformExpressionsUp {
|
||||
case a: Attribute if !a.resolved =>
|
||||
val from = operator.inputSet.map(_.name).mkString(", ")
|
||||
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
|
||||
|
||||
case c: Cast if !c.resolved =>
|
||||
failAnalysis(
|
||||
s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
|
||||
|
||||
case b: BinaryExpression if !b.resolved =>
|
||||
failAnalysis(
|
||||
s"invalid expression ${b.prettyString} " +
|
||||
s"between ${b.left.simpleString} and ${b.right.simpleString}")
|
||||
}
|
||||
|
||||
operator match {
|
||||
case f: Filter if f.condition.dataType != BooleanType =>
|
||||
failAnalysis(
|
||||
s"filter expression '${f.condition.prettyString}' " +
|
||||
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
|
||||
|
||||
case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
|
||||
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
|
||||
case _: AggregateExpression => // OK
|
||||
case e: Attribute if !groupingExprs.contains(e) =>
|
||||
failAnalysis(
|
||||
s"expression '${e.prettyString}' is neither present in the group by, " +
|
||||
s"nor is it an aggregate function. " +
|
||||
"Add to group by or wrap in first() if you don't care which value you get.")
|
||||
case e if groupingExprs.contains(e) => // OK
|
||||
case e if e.references.isEmpty => // OK
|
||||
case e => e.children.foreach(checkValidAggregateExpression)
|
||||
}
|
||||
|
||||
val cleaned = aggregateExprs.map(_.transform {
|
||||
// Should trim aliases around `GetField`s. These aliases are introduced while
|
||||
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
|
||||
// (Should we just turn `GetField` into a `NamedExpression`?)
|
||||
case Alias(g, _) => g
|
||||
})
|
||||
|
||||
cleaned.foreach(checkValidAggregateExpression)
|
||||
|
||||
case o if o.children.nonEmpty &&
|
||||
!o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) =>
|
||||
val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
|
||||
val input = o.inputSet.map(_.prettyString).mkString(",")
|
||||
|
||||
failAnalysis(s"resolved attributes $missingAttributes missing from $input")
|
||||
|
||||
// Catch all
|
||||
case o if !o.resolved =>
|
||||
failAnalysis(
|
||||
s"unresolved operator ${operator.simpleString}")
|
||||
|
||||
case _ => // Analysis successful!
|
||||
}
|
||||
}
|
||||
|
||||
plan
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes no-op Alias expressions from the plan.
|
||||
*/
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.analysis
|
||||
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
* Throws user facing errors when passed invalid queries that fail to analyze.
|
||||
*/
|
||||
class CheckAnalysis {
|
||||
|
||||
/**
|
||||
* Override to provide additional checks for correct analysis.
|
||||
* These rules will be evaluated after our built-in check rules.
|
||||
*/
|
||||
val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
|
||||
|
||||
def failAnalysis(msg: String) = {
|
||||
throw new AnalysisException(msg)
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): Unit = {
|
||||
// We transform up and order the rules so as to catch the first possible failure instead
|
||||
// of the result of cascading resolution failures.
|
||||
plan.foreachUp {
|
||||
case operator: LogicalPlan =>
|
||||
operator transformExpressionsUp {
|
||||
case a: Attribute if !a.resolved =>
|
||||
val from = operator.inputSet.map(_.name).mkString(", ")
|
||||
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
|
||||
|
||||
case c: Cast if !c.resolved =>
|
||||
failAnalysis(
|
||||
s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
|
||||
|
||||
case b: BinaryExpression if !b.resolved =>
|
||||
failAnalysis(
|
||||
s"invalid expression ${b.prettyString} " +
|
||||
s"between ${b.left.simpleString} and ${b.right.simpleString}")
|
||||
}
|
||||
|
||||
operator match {
|
||||
case f: Filter if f.condition.dataType != BooleanType =>
|
||||
failAnalysis(
|
||||
s"filter expression '${f.condition.prettyString}' " +
|
||||
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
|
||||
|
||||
case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) =>
|
||||
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
|
||||
case _: AggregateExpression => // OK
|
||||
case e: Attribute if !groupingExprs.contains(e) =>
|
||||
failAnalysis(
|
||||
s"expression '${e.prettyString}' is neither present in the group by, " +
|
||||
s"nor is it an aggregate function. " +
|
||||
"Add to group by or wrap in first() if you don't care which value you get.")
|
||||
case e if groupingExprs.contains(e) => // OK
|
||||
case e if e.references.isEmpty => // OK
|
||||
case e => e.children.foreach(checkValidAggregateExpression)
|
||||
}
|
||||
|
||||
val cleaned = aggregateExprs.map(_.transform {
|
||||
// Should trim aliases around `GetField`s. These aliases are introduced while
|
||||
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
|
||||
// (Should we just turn `GetField` into a `NamedExpression`?)
|
||||
case Alias(g, _) => g
|
||||
})
|
||||
|
||||
cleaned.foreach(checkValidAggregateExpression)
|
||||
|
||||
case o if o.children.nonEmpty &&
|
||||
!o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) =>
|
||||
val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
|
||||
val input = o.inputSet.map(_.prettyString).mkString(",")
|
||||
|
||||
failAnalysis(s"resolved attributes $missingAttributes missing from $input")
|
||||
|
||||
// Catch all
|
||||
case o if !o.resolved =>
|
||||
failAnalysis(
|
||||
s"unresolved operator ${operator.simpleString}")
|
||||
|
||||
case _ => // Analysis successful!
|
||||
}
|
||||
}
|
||||
extendedCheckRules.foreach(_(plan))
|
||||
}
|
||||
}
|
|
@ -30,11 +30,21 @@ import org.apache.spark.sql.catalyst.dsl.plans._
|
|||
class AnalysisSuite extends FunSuite with BeforeAndAfter {
|
||||
val caseSensitiveCatalog = new SimpleCatalog(true)
|
||||
val caseInsensitiveCatalog = new SimpleCatalog(false)
|
||||
val caseSensitiveAnalyze =
|
||||
|
||||
val caseSensitiveAnalyzer =
|
||||
new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true)
|
||||
val caseInsensitiveAnalyze =
|
||||
val caseInsensitiveAnalyzer =
|
||||
new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
|
||||
|
||||
val checkAnalysis = new CheckAnalysis
|
||||
|
||||
|
||||
def caseSensitiveAnalyze(plan: LogicalPlan) =
|
||||
checkAnalysis(caseSensitiveAnalyzer(plan))
|
||||
|
||||
def caseInsensitiveAnalyze(plan: LogicalPlan) =
|
||||
checkAnalysis(caseInsensitiveAnalyzer(plan))
|
||||
|
||||
val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
|
||||
val testRelation2 = LocalRelation(
|
||||
AttributeReference("a", StringType)(),
|
||||
|
@ -55,7 +65,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
|
|||
a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
|
||||
}
|
||||
|
||||
assert(caseInsensitiveAnalyze(plan).resolved)
|
||||
assert(caseInsensitiveAnalyzer(plan).resolved)
|
||||
}
|
||||
|
||||
test("check project's resolved") {
|
||||
|
@ -71,11 +81,11 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
|
|||
|
||||
test("analyze project") {
|
||||
assert(
|
||||
caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
|
||||
caseSensitiveAnalyzer(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
|
||||
Project(testRelation.output, testRelation))
|
||||
|
||||
assert(
|
||||
caseSensitiveAnalyze(
|
||||
caseSensitiveAnalyzer(
|
||||
Project(Seq(UnresolvedAttribute("TbL.a")),
|
||||
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
|
||||
Project(testRelation.output, testRelation))
|
||||
|
@ -88,13 +98,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
|
|||
assert(e.getMessage().toLowerCase.contains("cannot resolve"))
|
||||
|
||||
assert(
|
||||
caseInsensitiveAnalyze(
|
||||
caseInsensitiveAnalyzer(
|
||||
Project(Seq(UnresolvedAttribute("TbL.a")),
|
||||
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
|
||||
Project(testRelation.output, testRelation))
|
||||
|
||||
assert(
|
||||
caseInsensitiveAnalyze(
|
||||
caseInsensitiveAnalyzer(
|
||||
Project(Seq(UnresolvedAttribute("tBl.a")),
|
||||
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
|
||||
Project(testRelation.output, testRelation))
|
||||
|
@ -107,16 +117,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
|
|||
assert(e.getMessage == "Table Not Found: tAbLe")
|
||||
|
||||
assert(
|
||||
caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
|
||||
testRelation)
|
||||
caseSensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
|
||||
|
||||
assert(
|
||||
caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) ===
|
||||
testRelation)
|
||||
caseInsensitiveAnalyzer(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
|
||||
|
||||
assert(
|
||||
caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
|
||||
testRelation)
|
||||
caseInsensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
|
||||
}
|
||||
|
||||
def errorTest(
|
||||
|
@ -177,7 +184,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
|
|||
AttributeReference("d", DecimalType.Unlimited)(),
|
||||
AttributeReference("e", ShortType)())
|
||||
|
||||
val plan = caseInsensitiveAnalyze(
|
||||
val plan = caseInsensitiveAnalyzer(
|
||||
testRelation2.select(
|
||||
'a / Literal(2) as 'div1,
|
||||
'a / 'b as 'div2,
|
||||
|
|
|
@ -117,7 +117,7 @@ class DataFrame protected[sql](
|
|||
this(sqlContext, {
|
||||
val qe = sqlContext.executePlan(logicalPlan)
|
||||
if (sqlContext.conf.dataFrameEagerAnalysis) {
|
||||
qe.analyzed // This should force analysis and throw errors if there are any
|
||||
qe.assertAnalyzed() // This should force analysis and throw errors if there are any
|
||||
}
|
||||
qe
|
||||
})
|
||||
|
|
|
@ -52,8 +52,9 @@ private[spark] object SQLConf {
|
|||
// This is used to set the default data source
|
||||
val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default"
|
||||
|
||||
// Whether to perform eager analysis on a DataFrame.
|
||||
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.dataframe.eagerAnalysis"
|
||||
// Whether to perform eager analysis when constructing a dataframe.
|
||||
// Set to false when debugging requires the ability to look at invalid query plans.
|
||||
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis"
|
||||
|
||||
object Deprecated {
|
||||
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
|
||||
|
|
|
@ -114,7 +114,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
new Analyzer(catalog, functionRegistry, caseSensitive = true) {
|
||||
override val extendedResolutionRules =
|
||||
ExtractPythonUdfs ::
|
||||
sources.PreWriteCheck(catalog) ::
|
||||
sources.PreInsertCastAndRename ::
|
||||
Nil
|
||||
}
|
||||
|
@ -1057,6 +1056,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
Batch("Add exchange", Once, AddExchange(self)) :: Nil
|
||||
}
|
||||
|
||||
@transient
|
||||
protected[sql] lazy val checkAnalysis = new CheckAnalysis {
|
||||
override val extendedCheckRules = Seq(
|
||||
sources.PreWriteCheck(catalog)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
* The primary workflow for executing relational queries using Spark. Designed to allow easy
|
||||
|
@ -1064,9 +1070,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
*/
|
||||
@DeveloperApi
|
||||
protected[sql] class QueryExecution(val logical: LogicalPlan) {
|
||||
def assertAnalyzed(): Unit = checkAnalysis(analyzed)
|
||||
|
||||
lazy val analyzed: LogicalPlan = analyzer(logical)
|
||||
lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed)
|
||||
lazy val withCachedData: LogicalPlan = {
|
||||
assertAnalyzed
|
||||
cacheManager.useCachedData(analyzed)
|
||||
}
|
||||
lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)
|
||||
|
||||
// TODO: Don't just pick the first one...
|
||||
|
|
|
@ -78,10 +78,10 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
|
|||
/**
|
||||
* A rule to do various checks before inserting into or writing to a data source table.
|
||||
*/
|
||||
private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan] {
|
||||
private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) {
|
||||
def failAnalysis(msg: String) = { throw new AnalysisException(msg) }
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = {
|
||||
def apply(plan: LogicalPlan): Unit = {
|
||||
plan.foreach {
|
||||
case i @ logical.InsertIntoTable(
|
||||
l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) =>
|
||||
|
@ -93,7 +93,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan
|
|||
val srcRelations = query.collect {
|
||||
case LogicalRelation(src: BaseRelation) => src
|
||||
}
|
||||
if (srcRelations.exists(src => src == t)) {
|
||||
if (srcRelations.contains(t)) {
|
||||
failAnalysis(
|
||||
"Cannot insert overwrite into table that is also being read from.")
|
||||
} else {
|
||||
|
@ -119,7 +119,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan
|
|||
val srcRelations = query.collect {
|
||||
case LogicalRelation(src: BaseRelation) => src
|
||||
}
|
||||
if (srcRelations.exists(src => src == dest)) {
|
||||
if (srcRelations.contains(dest)) {
|
||||
failAnalysis(
|
||||
s"Cannot overwrite table $tableName that is also being read from.")
|
||||
} else {
|
||||
|
@ -134,7 +134,5 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan
|
|||
|
||||
case _ => // OK
|
||||
}
|
||||
|
||||
plan
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,6 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
|
|||
override protected[sql] lazy val analyzer: Analyzer =
|
||||
new Analyzer(catalog, functionRegistry, caseSensitive = false) {
|
||||
override val extendedResolutionRules =
|
||||
PreWriteCheck(catalog) ::
|
||||
PreInsertCastAndRename ::
|
||||
Nil
|
||||
}
|
||||
|
|
|
@ -205,7 +205,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
|
|||
val message = intercept[AnalysisException] {
|
||||
sql(
|
||||
s"""
|
||||
|INSERT OVERWRITE TABLE oneToTen SELECT a FROM jt
|
||||
|INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt
|
||||
""".stripMargin)
|
||||
}.getMessage
|
||||
assert(
|
||||
|
|
|
@ -268,7 +268,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
|
|||
catalog.PreInsertionCasts ::
|
||||
ExtractPythonUdfs ::
|
||||
ResolveUdtfsAlias ::
|
||||
sources.PreWriteCheck(catalog) ::
|
||||
sources.PreInsertCastAndRename ::
|
||||
Nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue