[SPARK-5873][SQL] Allow viewing of partially analyzed plans in queryExecution

Author: Michael Armbrust <michael@databricks.com>

Closes #4684 from marmbrus/explainAnalysis and squashes the following commits:

afbaa19 [Michael Armbrust] fix python
d93278c [Michael Armbrust] fix hive
e5fa0a4 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explainAnalysis
52119f2 [Michael Armbrust] more tests
82a5431 [Michael Armbrust] fix tests
25753d2 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explainAnalysis
aee1e6a [Michael Armbrust] fix hive
b23a844 [Michael Armbrust] newline
de8dc51 [Michael Armbrust] more comments
acf620a [Michael Armbrust] [SPARK-5873][SQL] Show partially analyzed plans in query execution
Michael Armbrust 2015-02-23 17:34:54 -08:00
parent 48376bfe9c
commit 1ed57086d4
12 changed files with 164 additions and 126 deletions
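
The net effect of the change: with eager analysis disabled, an invalid DataFrame can still be constructed and its partially analyzed plan inspected, instead of failing at construction time. A minimal sketch of that workflow, assuming a Spark 1.3-era sqlContext, a registered table people, and a nonexistent column missingCol (all hypothetical), and assuming queryExecution is reachable from user code:

    // Hedged sketch; the table and column names are made up.
    sqlContext.setConf("spark.sql.eagerAnalysis", "false")   // key as renamed in this commit

    val df = sqlContext.sql("SELECT missingCol FROM people")

    // Construction no longer throws; the partially analyzed plan stays visible for debugging.
    println(df.queryExecution.analyzed)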

@@ -267,20 +267,20 @@ class SQLContext(object):
... StructField("byte2", ByteType(), False),
... StructField("short1", ShortType(), False),
... StructField("short2", ShortType(), False),
... StructField("int", IntegerType(), False),
... StructField("float", FloatType(), False),
... StructField("date", DateType(), False),
... StructField("time", TimestampType(), False),
... StructField("map",
... StructField("int1", IntegerType(), False),
... StructField("float1", FloatType(), False),
... StructField("date1", DateType(), False),
... StructField("time1", TimestampType(), False),
... StructField("map1",
... MapType(StringType(), IntegerType(), False), False),
... StructField("struct",
... StructField("struct1",
... StructType([StructField("b", ShortType(), False)]), False),
... StructField("list", ArrayType(ByteType(), False), False),
... StructField("null", DoubleType(), True)])
... StructField("list1", ArrayType(ByteType(), False), False),
... StructField("null1", DoubleType(), True)])
>>> df = sqlCtx.applySchema(rdd, schema)
>>> results = df.map(
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
... x.time, x.map["a"], x.struct.b, x.list, x.null))
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
... x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
>>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
(127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
@@ -288,20 +288,20 @@ class SQLContext(object):
>>> df.registerTempTable("table2")
>>> sqlCtx.sql(
... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
... "float + 1.5 as float FROM table2").collect()
[Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)]
... "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
... "float1 + 1.5 as float1 FROM table2").collect()
[Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int1=2147483646, float1=2.5)]
>>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
>>> rdd = sc.parallelize([(127, -32768, 1.0,
... datetime(2010, 1, 1, 1, 1, 1),
... {"a": 1}, (2,), [1, 2, 3])])
>>> abstract = "byte short float time map{} struct(b) list[]"
>>> abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
>>> schema = _parse_schema_abstract(abstract)
>>> typedSchema = _infer_schema_type(rdd.first(), schema)
>>> df = sqlCtx.applySchema(rdd, typedSchema)
>>> df.collect()
[Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
[Row(byte1=127, short1=-32768, float1=1.0, time1=..., list1=[1, 2, 3])]
"""
if isinstance(rdd, DataFrame):

@@ -78,6 +78,7 @@ class SqlParser extends AbstractSparkSQLParser {
protected val IF = Keyword("IF")
protected val IN = Keyword("IN")
protected val INNER = Keyword("INNER")
protected val INT = Keyword("INT")
protected val INSERT = Keyword("INSERT")
protected val INTERSECT = Keyword("INTERSECT")
protected val INTO = Keyword("INTO")
@@ -394,6 +395,7 @@ class SqlParser extends AbstractSparkSQLParser {
| fixedDecimalType
| DECIMAL ^^^ DecimalType.Unlimited
| DATE ^^^ DateType
| INT ^^^ IntegerType
)
protected lazy val fixedDecimalType: Parser[DataType] =
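
The new keyword is what lets the updated InsertSuite query further down write CAST(a AS INT). A hedged sketch, assuming a registered temp table jt with an integer-typed column a, and assuming the plain (non-Hive) parser did not accept INT before this addition, as the diff suggests:

    // INT is now a keyword mapped to IntegerType, so this parses in SqlParser
    // rather than only in the Hive dialect (sketch; jt and a are borrowed from the test suite below).
    sqlContext.sql("SELECT CAST(a AS INT) FROM jt")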

@@ -52,12 +52,6 @@ class Analyzer(catalog: Catalog,
*/
val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
/**
* Override to provide additional rules for the "Check Analysis" batch.
* These rules will be evaluated after our built-in check rules.
*/
val extendedCheckRules: Seq[Rule[LogicalPlan]] = Nil
lazy val batches: Seq[Batch] = Seq(
Batch("Resolution", fixedPoint,
ResolveRelations ::
@@ -71,87 +65,10 @@ class Analyzer(catalog: Catalog,
TrimGroupingAliases ::
typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Check Analysis", Once,
CheckResolution +:
extendedCheckRules: _*),
Batch("Remove SubQueries", fixedPoint,
EliminateSubQueries)
)
/**
* Makes sure all attributes and logical plans have been resolved.
*/
object CheckResolution extends Rule[LogicalPlan] {
def failAnalysis(msg: String) = { throw new AnalysisException(msg) }
def apply(plan: LogicalPlan): LogicalPlan = {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
plan.foreachUp {
case operator: LogicalPlan =>
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
case c: Cast if !c.resolved =>
failAnalysis(
s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
case b: BinaryExpression if !b.resolved =>
failAnalysis(
s"invalid expression ${b.prettyString} " +
s"between ${b.left.simpleString} and ${b.right.simpleString}")
}
operator match {
case f: Filter if f.condition.dataType != BooleanType =>
failAnalysis(
s"filter expression '${f.condition.prettyString}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
case e: Attribute if !groupingExprs.contains(e) =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() if you don't care which value you get.")
case e if groupingExprs.contains(e) => // OK
case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
val cleaned = aggregateExprs.map(_.transform {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
case Alias(g, _) => g
})
cleaned.foreach(checkValidAggregateExpression)
case o if o.children.nonEmpty &&
!o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) =>
val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
val input = o.inputSet.map(_.prettyString).mkString(",")
failAnalysis(s"resolved attributes $missingAttributes missing from $input")
// Catch all
case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")
case _ => // Analysis successful!
}
}
plan
}
}
/**
* Removes no-op Alias expressions from the plan.
*/

@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
/**
* Throws user facing errors when passed invalid queries that fail to analyze.
*/
class CheckAnalysis {
/**
* Override to provide additional checks for correct analysis.
* These rules will be evaluated after our built-in check rules.
*/
val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
def failAnalysis(msg: String) = {
throw new AnalysisException(msg)
}
def apply(plan: LogicalPlan): Unit = {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
plan.foreachUp {
case operator: LogicalPlan =>
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
case c: Cast if !c.resolved =>
failAnalysis(
s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
case b: BinaryExpression if !b.resolved =>
failAnalysis(
s"invalid expression ${b.prettyString} " +
s"between ${b.left.simpleString} and ${b.right.simpleString}")
}
operator match {
case f: Filter if f.condition.dataType != BooleanType =>
failAnalysis(
s"filter expression '${f.condition.prettyString}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
case e: Attribute if !groupingExprs.contains(e) =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() if you don't care which value you get.")
case e if groupingExprs.contains(e) => // OK
case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
val cleaned = aggregateExprs.map(_.transform {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
case Alias(g, _) => g
})
cleaned.foreach(checkValidAggregateExpression)
case o if o.children.nonEmpty &&
!o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) =>
val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
val input = o.inputSet.map(_.prettyString).mkString(",")
failAnalysis(s"resolved attributes $missingAttributes missing from $input")
// Catch all
case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")
case _ => // Analysis successful!
}
}
extendedCheckRules.foreach(_(plan))
}
}
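
Because extendedCheckRules is now a sequence of plain LogicalPlan => Unit functions rather than rewrite rules, an extra check only needs to throw when it sees a bad plan. A hypothetical extension, sketched against the class above; the column-count limit is invented purely for illustration:

    import org.apache.spark.sql.AnalysisException
    import org.apache.spark.sql.catalyst.analysis.CheckAnalysis
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Invented check, not part of Spark: reject plans with an absurd number of output columns.
    // (Assumes the AnalysisException constructor is visible from this code; any exception would do.)
    object MaxOutputColumns extends (LogicalPlan => Unit) {
      def apply(plan: LogicalPlan): Unit = {
        if (plan.output.size > 10000) {
          throw new AnalysisException(s"query produces ${plan.output.size} columns; refusing to run")
        }
      }
    }

    val checkAnalysis = new CheckAnalysis {
      override val extendedCheckRules = Seq(MaxOutputColumns)
    }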

@@ -30,11 +30,21 @@ import org.apache.spark.sql.catalyst.dsl.plans._
class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseSensitiveCatalog = new SimpleCatalog(true)
val caseInsensitiveCatalog = new SimpleCatalog(false)
val caseSensitiveAnalyze =
val caseSensitiveAnalyzer =
new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true)
val caseInsensitiveAnalyze =
val caseInsensitiveAnalyzer =
new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
val checkAnalysis = new CheckAnalysis
def caseSensitiveAnalyze(plan: LogicalPlan) =
checkAnalysis(caseSensitiveAnalyzer(plan))
def caseInsensitiveAnalyze(plan: LogicalPlan) =
checkAnalysis(caseInsensitiveAnalyzer(plan))
val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
val testRelation2 = LocalRelation(
AttributeReference("a", StringType)(),
@@ -55,7 +65,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
}
assert(caseInsensitiveAnalyze(plan).resolved)
assert(caseInsensitiveAnalyzer(plan).resolved)
}
test("check project's resolved") {
@@ -71,11 +81,11 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
test("analyze project") {
assert(
caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
caseSensitiveAnalyzer(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
Project(testRelation.output, testRelation))
assert(
caseSensitiveAnalyze(
caseSensitiveAnalyzer(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
@@ -88,13 +98,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(e.getMessage().toLowerCase.contains("cannot resolve"))
assert(
caseInsensitiveAnalyze(
caseInsensitiveAnalyzer(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
assert(
caseInsensitiveAnalyze(
caseInsensitiveAnalyzer(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
@@ -107,16 +117,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(e.getMessage == "Table Not Found: tAbLe")
assert(
caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
testRelation)
caseSensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
assert(
caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) ===
testRelation)
caseInsensitiveAnalyzer(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
assert(
caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
testRelation)
caseInsensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
}
def errorTest(
@@ -177,7 +184,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())
val plan = caseInsensitiveAnalyze(
val plan = caseInsensitiveAnalyzer(
testRelation2.select(
'a / Literal(2) as 'div1,
'a / 'b as 'div2,

@@ -117,7 +117,7 @@ class DataFrame protected[sql](
this(sqlContext, {
val qe = sqlContext.executePlan(logicalPlan)
if (sqlContext.conf.dataFrameEagerAnalysis) {
qe.analyzed // This should force analysis and throw errors if there are any
qe.assertAnalyzed() // This should force analysis and throw errors if there are any
}
qe
})

@@ -52,8 +52,9 @@ private[spark] object SQLConf {
// This is used to set the default data source
val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default"
// Whether to perform eager analysis on a DataFrame.
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.dataframe.eagerAnalysis"
// Whether to perform eager analysis when constructing a dataframe.
// Set to false when debugging requires the ability to look at invalid query plans.
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis"
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
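
Note the key itself is renamed here, from spark.sql.dataframe.eagerAnalysis to spark.sql.eagerAnalysis, so any script setting the old name would silently stop applying. A hedged two-liner with the new name, assuming a sqlContext in scope:

    sqlContext.setConf("spark.sql.eagerAnalysis", "false")   // disable to inspect invalid plans
    sqlContext.getConf("spark.sql.eagerAnalysis", "true")    // assumed default when the key is unset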

@@ -114,7 +114,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
new Analyzer(catalog, functionRegistry, caseSensitive = true) {
override val extendedResolutionRules =
ExtractPythonUdfs ::
sources.PreWriteCheck(catalog) ::
sources.PreInsertCastAndRename ::
Nil
}
@@ -1057,6 +1056,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
Batch("Add exchange", Once, AddExchange(self)) :: Nil
}
@transient
protected[sql] lazy val checkAnalysis = new CheckAnalysis {
override val extendedCheckRules = Seq(
sources.PreWriteCheck(catalog)
)
}
/**
* :: DeveloperApi ::
* The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -1064,9 +1070,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@DeveloperApi
protected[sql] class QueryExecution(val logical: LogicalPlan) {
def assertAnalyzed(): Unit = checkAnalysis(analyzed)
lazy val analyzed: LogicalPlan = analyzer(logical)
lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed)
lazy val withCachedData: LogicalPlan = {
assertAnalyzed
cacheManager.useCachedData(analyzed)
}
lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)
// TODO: Don't just pick the first one...
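
Because assertAnalyzed now runs when withCachedData is materialized, an invalid plan stays inspectable but still fails as soon as anything downstream (optimization, execution) is requested. A hedged continuation of the hypothetical df with the unresolved column from the sketch near the top:

    println(df.queryExecution.logical)    // unresolved plan, always available
    println(df.queryExecution.analyzed)   // partially analyzed plan; no longer throws here
    df.collect()                          // still fails: assertAnalyzed fires before execution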

@@ -78,10 +78,10 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
/**
* A rule to do various checks before inserting into or writing to a data source table.
*/
private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan] {
private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) {
def failAnalysis(msg: String) = { throw new AnalysisException(msg) }
def apply(plan: LogicalPlan): LogicalPlan = {
def apply(plan: LogicalPlan): Unit = {
plan.foreach {
case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) =>
@@ -93,7 +93,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan
val srcRelations = query.collect {
case LogicalRelation(src: BaseRelation) => src
}
if (srcRelations.exists(src => src == t)) {
if (srcRelations.contains(t)) {
failAnalysis(
"Cannot insert overwrite into table that is also being read from.")
} else {
@@ -119,7 +119,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan
val srcRelations = query.collect {
case LogicalRelation(src: BaseRelation) => src
}
if (srcRelations.exists(src => src == dest)) {
if (srcRelations.contains(dest)) {
failAnalysis(
s"Cannot overwrite table $tableName that is also being read from.")
} else {
@@ -134,7 +134,5 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan
case _ => // OK
}
plan
}
}
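
For reference, the self-insert case these checks reject, sketched with the oneToTen table name used by the test suite below (the behavior is described by the error strings above; this is not a new test):

    // Reading from and overwriting the same data source table in one statement is rejected.
    sqlContext.sql("INSERT OVERWRITE TABLE oneToTen SELECT * FROM oneToTen")
    // => AnalysisException: Cannot insert overwrite into table that is also being read from.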

@@ -30,7 +30,6 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
override protected[sql] lazy val analyzer: Analyzer =
new Analyzer(catalog, functionRegistry, caseSensitive = false) {
override val extendedResolutionRules =
PreWriteCheck(catalog) ::
PreInsertCastAndRename ::
Nil
}

@@ -205,7 +205,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
val message = intercept[AnalysisException] {
sql(
s"""
|INSERT OVERWRITE TABLE oneToTen SELECT a FROM jt
|INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt
""".stripMargin)
}.getMessage
assert(

@@ -268,7 +268,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.PreInsertionCasts ::
ExtractPythonUdfs ::
ResolveUdtfsAlias ::
sources.PreWriteCheck(catalog) ::
sources.PreInsertCastAndRename ::
Nil
}