[SPARK-6054][SQL] Fix transformations of TreeNodes that hold StructTypes
Due to a recent change that made `StructType` a `Seq`, we started inadvertently turning `StructType`s into generic `Traversable`s when attempting nested tree transformations. In this PR we explicitly skip descending into `DataType`s to prevent this bug. Author: Michael Armbrust <michael@databricks.com> Closes #5157 from marmbrus/udfFix and squashes the following commits: 26f7087 [Michael Armbrust] Fix transformations of TreeNodes that hold StructTypes
This commit is contained in:
parent
26c6ce3d29
commit
3fa3d121df
|
@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
|
|||
case e: Expression => transformExpressionDown(e)
|
||||
case Some(e: Expression) => Some(transformExpressionDown(e))
|
||||
case m: Map[_,_] => m
|
||||
case d: DataType => d // Avoid unpacking Structs
|
||||
case seq: Traversable[_] => seq.map {
|
||||
case e: Expression => transformExpressionDown(e)
|
||||
case other => other
|
||||
|
@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
|
|||
case e: Expression => transformExpressionUp(e)
|
||||
case Some(e: Expression) => Some(transformExpressionUp(e))
|
||||
case m: Map[_,_] => m
|
||||
case d: DataType => d // Avoid unpacking Structs
|
||||
case seq: Traversable[_] => seq.map {
|
||||
case e: Expression => transformExpressionUp(e)
|
||||
case other => other
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.trees
|
||||
|
||||
import org.apache.spark.sql.catalyst.errors._
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
|
||||
private class MutableInt(var i: Int)
|
||||
|
@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
|
|||
Some(arg)
|
||||
}
|
||||
case m: Map[_,_] => m
|
||||
case d: DataType => d // Avoid unpacking Structs
|
||||
case args: Traversable[_] => args.map {
|
||||
case arg: TreeNode[_] if children contains arg =>
|
||||
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
|
||||
|
@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
|
|||
Some(arg)
|
||||
}
|
||||
case m: Map[_,_] => m
|
||||
case d: DataType => d // Avoid unpacking Structs
|
||||
case args: Traversable[_] => args.map {
|
||||
case arg: TreeNode[_] if children contains arg =>
|
||||
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
|
||||
|
@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
|
|||
* @param newArgs the new product arguments.
|
||||
*/
|
||||
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
|
||||
val defaultCtor =
|
||||
getClass.getConstructors
|
||||
.find(_.getParameterTypes.size != 0)
|
||||
.headOption
|
||||
.getOrElse(sys.error(s"No valid constructor for $nodeName"))
|
||||
|
||||
try {
|
||||
CurrentOrigin.withOrigin(origin) {
|
||||
// Skip no-arg constructors that are just there for kryo.
|
||||
val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
|
||||
if (otherCopyArgs.isEmpty) {
|
||||
defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
|
||||
} else {
|
||||
|
@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
|
|||
} catch {
|
||||
case e: java.lang.IllegalArgumentException =>
|
||||
throw new TreeNodeException(
|
||||
this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? "
|
||||
+ s"Exception message: ${e.getMessage}.")
|
||||
this,
|
||||
s"""
|
||||
|Failed to copy node.
|
||||
|Is otherCopyArgs specified correctly for $nodeName.
|
||||
|Exception message: ${e.getMessage}
|
||||
|ctor: $defaultCtor?
|
||||
|args: ${newArgs.mkString(", ")}
|
||||
""".stripMargin)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -50,4 +50,10 @@ class UDFSuite extends QueryTest {
|
|||
.select($"ret.f1").head().getString(0)
|
||||
assert(result === "test")
|
||||
}
|
||||
|
||||
test("udf that is transformed") {
|
||||
udf.register("makeStruct", (x: Int, y: Int) => (x, y))
|
||||
// 1 + 1 is constant folded causing a transformation.
|
||||
assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue