[SPARK-6054][SQL] Fix transformations of TreeNodes that hold StructTypes

Due to a recent change that made `StructType` a `Seq`, we started inadvertently turning `StructType`s into generic `Traversable`s when performing nested tree transformations. In this PR we explicitly avoid descending into `DataType`s to prevent this bug.
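As a rough sketch of the failure mode (this snippet is illustrative and not part of the patch; the field names are made up): because `StructType` now extends `Seq[StructField]`, a generic collection `map` over it goes through the default `Seq` builder and rebuilds a plain `Seq` instead of a `StructType`.

```scala
import org.apache.spark.sql.types._

// StructType is a Seq[StructField], so a generic `map` returns a
// plain Seq[StructField], not a StructType.
val schema = StructType(Seq(
  StructField("f1", StringType),
  StructField("f2", IntegerType)))

val remapped = schema.map(identity)
assert(!remapped.isInstanceOf[StructType]) // the StructType wrapper is lost
```

This is what happened when tree transformations descended into expression data types, which is why the fix short-circuits on `DataType` before the generic `Traversable` case.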

Author: Michael Armbrust <michael@databricks.com>

Closes #5157 from marmbrus/udfFix and squashes the following commits:

26f7087 [Michael Armbrust] Fix transformations of TreeNodes that hold StructTypes
Michael Armbrust 2015-03-24 12:28:01 -07:00
parent 26c6ce3d29
commit 3fa3d121df
3 changed files with 25 additions and 3 deletions


@@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       case e: Expression => transformExpressionDown(e)
       case Some(e: Expression) => Some(transformExpressionDown(e))
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map {
         case e: Expression => transformExpressionDown(e)
         case other => other
@@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       case e: Expression => transformExpressionUp(e)
       case Some(e: Expression) => Some(transformExpressionUp(e))
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map {
         case e: Expression => transformExpressionUp(e)
         case other => other


@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.trees
 
 import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.types.DataType
 
 /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
 private class MutableInt(var i: Int)
@@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
           Some(arg)
         }
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case args: Traversable[_] => args.map {
         case arg: TreeNode[_] if children contains arg =>
           val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
@@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
           Some(arg)
         }
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case args: Traversable[_] => args.map {
         case arg: TreeNode[_] if children contains arg =>
           val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
@@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
    * @param newArgs the new product arguments.
    */
   def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
+    val defaultCtor =
+      getClass.getConstructors
+        .find(_.getParameterTypes.size != 0)
+        .headOption
+        .getOrElse(sys.error(s"No valid constructor for $nodeName"))
+
     try {
       CurrentOrigin.withOrigin(origin) {
         // Skip no-arg constructors that are just there for kryo.
-        val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
         if (otherCopyArgs.isEmpty) {
           defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
         } else {
@@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
     } catch {
       case e: java.lang.IllegalArgumentException =>
         throw new TreeNodeException(
-          this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? "
-            + s"Exception message: ${e.getMessage}.")
+          this,
+          s"""
+             |Failed to copy node.
+             |Is otherCopyArgs specified correctly for $nodeName.
+             |Exception message: ${e.getMessage}
+             |ctor: $defaultCtor?
+             |args: ${newArgs.mkString(", ")}
+           """.stripMargin)
     }
   }


@@ -50,4 +50,10 @@ class UDFSuite extends QueryTest {
         .select($"ret.f1").head().getString(0)
     assert(result === "test")
   }
+
+  test("udf that is transformed") {
+    udf.register("makeStruct", (x: Int, y: Int) => (x, y))
+    // 1 + 1 is constant folded causing a transformation.
+    assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
+  }
 }