[SPARK-12441][SQL] Fixing missingInput in Generate/MapPartitions/AppendColumns/MapGroups/CoGroup

When explaining any plan that contains `Generate`, we see an exclamation mark on that node. Normally, this mark means the plan has an error. This PR corrects the `missingInput` in `Generate` so the mark no longer appears.

For example,
```scala
val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters")
val df2 =
  df.explode('letters) {
    case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
  }

df2.explain(true)
```
Before the fix, the plan looks like this:
```
== Parsed Logical Plan ==
'Generate UserDefinedGenerator('letters), true, false, None
+- Project [_1#0 AS number#2,_2#1 AS letters#3]
   +- LocalRelation [_1#0,_2#1], [[1,a b c],[2,a b],[3,a]]

== Analyzed Logical Plan ==
number: int, letters: string, _1: string
Generate UserDefinedGenerator(letters#3), true, false, None, [_1#8]
+- Project [_1#0 AS number#2,_2#1 AS letters#3]
   +- LocalRelation [_1#0,_2#1], [[1,a b c],[2,a b],[3,a]]

== Optimized Logical Plan ==
Generate UserDefinedGenerator(letters#3), true, false, None, [_1#8]
+- LocalRelation [number#2,letters#3], [[1,a b c],[2,a b],[3,a]]

== Physical Plan ==
!Generate UserDefinedGenerator(letters#3), true, false, [number#2,letters#3,_1#8]
+- LocalTableScan [number#2,letters#3], [[1,a b c],[2,a b],[3,a]]
```
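After the fix, the generator output `_1#8` is reported through `producedAttributes` instead of `missingInput`, so the exclamation mark disappears from the physical plan (expected output shown for illustration; expression IDs may differ):
```
== Physical Plan ==
Generate UserDefinedGenerator(letters#3), true, false, [number#2,letters#3,_1#8]
+- LocalTableScan [number#2,letters#3], [[1,a b c],[2,a b],[3,a]]
```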

**Updates**: The same issue also exists in four other Dataset operators: `MapPartitions`/`AppendColumns`/`MapGroups`/`CoGroup`. This PR fixes all four as well.
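For instance, before this fix the same mark showed up for a plain `mapPartitions` call on a Dataset. A minimal sketch to reproduce it (assuming a Spark 1.6 shell with `sqlContext.implicits._` in scope):
```scala
import sqlContext.implicits._

val ds = Seq(1, 2, 3).toDS()
// Prior to this patch, explain printed a `!MapPartitions ...` node because the
// operator's output attributes were treated as missing inputs.
val ds2 = ds.mapPartitions(iter => iter.map(_ + 1))
ds2.toDF().explain(true)
```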

Author: gatorsmile <gatorsmile@gmail.com>
Author: xiaoli <lixiao1983@gmail.com>
Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local>

Closes #10393 from gatorsmile/generateExplain.
gatorsmile 2015-12-28 12:48:30 -08:00 committed by Michael Armbrust
parent a6a4812434
commit 01ba95d8bf
15 changed files with 63 additions and 18 deletions

@ -43,16 +43,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
def inputSet: AttributeSet =
AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output))
/**
* The set of all attributes that are produced by this node.
*/
def producedAttributes: AttributeSet = AttributeSet.empty
/**
* Attributes that are referenced by expressions but not provided by this node's children.
* Subclasses should override this method if they produce attributes internally as it is used by
* assertions designed to prevent the construction of invalid plans.
*
* Note that virtual columns should be excluded. Currently, we only support the grouping ID
* virtual column.
*/
def missingInput: AttributeSet =
(references -- inputSet).filter(_.name != VirtualColumn.groupingIdName)
def missingInput: AttributeSet = references -- inputSet -- producedAttributes
/**
* Runs [[transform]] with `rule` on all expressions present in this query operator.

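To make the new `missingInput` definition above concrete, here is a small illustration of the set arithmetic. It is not part of the patch; it assumes `spark-catalyst` on the classpath and constructs the attributes ad hoc:
```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
import org.apache.spark.sql.types.StringType

val letters = AttributeReference("letters", StringType)()   // provided by the child
val generated = AttributeReference("_1", StringType)()      // created by the operator itself

val references = AttributeSet(letters :: generated :: Nil)  // attributes the operator mentions
val inputSet = AttributeSet(letters :: Nil)                 // attributes the children provide
val producedAttributes = AttributeSet(generated :: Nil)     // attributes the operator produces

// Old definition: references -- inputSet still contains `generated`.
// New definition subtracts producedAttributes as well, so nothing is "missing".
assert((references -- inputSet -- producedAttributes).isEmpty)
```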
@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.types.{StructField, StructType}
object LocalRelation {

@ -295,6 +295,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
abstract class LeafNode extends LogicalPlan {
override def children: Seq[LogicalPlan] = Nil
override def producedAttributes: AttributeSet = outputSet
}
/**

@ -526,7 +526,7 @@ case class MapPartitions[T, U](
uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def missingInput: AttributeSet = AttributeSet.empty
override def producedAttributes: AttributeSet = outputSet
}
/** Factory for constructing new `AppendColumn` nodes. */
@ -552,7 +552,7 @@ case class AppendColumns[T, U](
newColumns: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output ++ newColumns
override def missingInput: AttributeSet = super.missingInput -- newColumns
override def producedAttributes: AttributeSet = AttributeSet(newColumns)
}
/** Factory for constructing new `MapGroups` nodes. */
@ -587,7 +587,7 @@ case class MapGroups[K, T, U](
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def missingInput: AttributeSet = AttributeSet.empty
override def producedAttributes: AttributeSet = outputSet
}
/** Factory for constructing new `CoGroup` nodes. */
@ -630,5 +630,5 @@ case class CoGroup[Key, Left, Right, Result](
rightGroup: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode {
override def missingInput: AttributeSet = AttributeSet.empty
override def producedAttributes: AttributeSet = outputSet
}

@ -18,11 +18,11 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation}
import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SQLContext}
@ -84,6 +84,8 @@ private[sql] case class LogicalRDD(
case _ => false
}
override def producedAttributes: AttributeSet = outputSet
@transient override lazy val statistics: Statistics = Statistics(
// TODO: Instead of returning a default value here, find a way to return a meaningful size
// estimate for RDDs. See PR 1238 for more discussions.

@ -54,6 +54,8 @@ case class Generate(
child: SparkPlan)
extends UnaryNode {
override def expressions: Seq[Expression] = generator :: Nil
val boundGenerator = BindReferences.bindReference(generator, child.output)
protected override def doExecute(): RDD[InternalRow] = {

@ -279,6 +279,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
private[sql] trait LeafNode extends SparkPlan {
override def children: Seq[SparkPlan] = Nil
override def producedAttributes: AttributeSet = outputSet
}
private[sql] trait UnaryNode extends SparkPlan {

@ -36,6 +36,15 @@ case class SortBasedAggregate(
child: SparkPlan)
extends UnaryNode {
private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}
override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)
override private[sql] lazy val metrics = Map(
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

@ -55,6 +55,11 @@ case class TungstenAggregate(
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)
override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.length == 0 => AllTuples :: Nil

@ -369,6 +369,7 @@ case class MapPartitions[T, U](
uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = outputSet
override def canProcessSafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true
@ -391,6 +392,7 @@ case class AppendColumns[T, U](
uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = AttributeSet(newColumns)
// We are using an unsafe combiner.
override def canProcessSafeRows: Boolean = false
@ -424,6 +426,7 @@ case class MapGroups[K, T, U](
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = outputSet
override def canProcessSafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true
@ -467,6 +470,7 @@ case class CoGroup[Key, Left, Right, Result](
rightGroup: Seq[Attribute],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
override def producedAttributes: AttributeSet = outputSet
override def canProcessSafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true

@ -66,6 +66,8 @@ private[sql] case class InMemoryRelation(
private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
extends LogicalPlan with MultiInstanceRelation {
override def producedAttributes: AttributeSet = outputSet
private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] =
if (_batchStats == null) {
child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow])

@ -15,16 +15,14 @@
* limitations under the License.
*/
package test.org.apache.spark.sql
package org.apache.spark.sql
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.UTF8String
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
@ -34,6 +32,7 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
sparkContext.parallelize(Seq(row))
}
override def producedAttributes: AttributeSet = outputSet
override def children: Seq[SparkPlan] = Nil
}

@ -130,6 +130,8 @@ abstract class QueryTest extends PlanTest {
checkJsonFormat(analyzedDF)
assertEmptyMissingInput(df)
QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
@ -275,6 +277,18 @@ abstract class QueryTest extends PlanTest {
""".stripMargin)
}
}
/**
* Asserts that a given [[Queryable]] does not have missing inputs in all the analyzed plans.
*/
def assertEmptyMissingInput(query: Queryable): Unit = {
assert(query.queryExecution.analyzed.missingInput.isEmpty,
s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}")
assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}")
assert(query.queryExecution.executedPlan.missingInput.isEmpty,
s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}")
}
}
object QueryTest {

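A hedged usage sketch for the `assertEmptyMissingInput` helper added above (inside a suite extending `QueryTest` with `SharedSQLContext`; imports and the test name are assumed), reusing the `explode` query from the description:
```scala
test("SPARK-12441: explode query has no missing inputs") {
  import testImplicits._

  val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters")
  val df2 = df.explode('letters) {
    case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
  }
  // Checks that the analyzed, optimized, and physical plans all resolve their inputs.
  assertEmptyMissingInput(df2)
}
```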
@ -51,6 +51,9 @@ case class HiveTableScan(
require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")
override def producedAttributes: AttributeSet = outputSet ++
AttributeSet(partitionPruningPred.flatMap(_.references))
// Retrieve the original attributes based on expression ID so that capitalization matches.
val attributes = requestedAttributes.map(relation.attributeMap)

@ -60,6 +60,8 @@ case class ScriptTransformation(
override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil
override def producedAttributes: AttributeSet = outputSet -- inputSet
private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)
protected override def doExecute(): RDD[InternalRow] = {