[SPARK-35411][SQL][FOLLOWUP] Handle Currying Product while serializing TreeNode to JSON
### What changes were proposed in this pull request? Handle Currying Product while serializing TreeNode to JSON. While processing [Product](https://github.com/apache/spark/blob/v3.1.2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L820), we may get an assert error for cases like Currying Product because of the mismatch of sizes between field name and field values. Fallback to use reflection to get all the values for constructor parameters when we meet such cases. ### Why are the changes needed? Avoid throwing error while serializing TreeNode to JSON, try to output as much information as possible. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New UT case added. Closes #32713 from ivoson/SPARK-35411-followup. Authored-by: Tengfei Huang <tengfei.h@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
14e12c64d3
commit
1603775934
|
@ -1030,7 +1030,20 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
|
||||||
case p: Product if shouldConvertToJson(p) =>
|
case p: Product if shouldConvertToJson(p) =>
|
||||||
try {
|
try {
|
||||||
val fieldNames = getConstructorParameterNames(p.getClass)
|
val fieldNames = getConstructorParameterNames(p.getClass)
|
||||||
val fieldValues = p.productIterator.toSeq
|
val fieldValues = {
|
||||||
|
if (p.productArity == fieldNames.length) {
|
||||||
|
p.productIterator.toSeq
|
||||||
|
} else {
|
||||||
|
val clazz = p.getClass
|
||||||
|
// Fallback to use reflection if length of product elements do not match
|
||||||
|
// constructor params.
|
||||||
|
fieldNames.map { fieldName =>
|
||||||
|
val field = clazz.getDeclaredField(fieldName)
|
||||||
|
field.setAccessible(true)
|
||||||
|
field.get(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
assert(fieldNames.length == fieldValues.length, s"$simpleClassName fields: " +
|
assert(fieldNames.length == fieldValues.length, s"$simpleClassName fields: " +
|
||||||
fieldNames.mkString(", ") + s", values: " + fieldValues.mkString(", "))
|
fieldNames.mkString(", ") + s", values: " + fieldValues.mkString(", "))
|
||||||
("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
|
("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
|
||||||
|
@ -1038,6 +1051,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
|
||||||
}.toList
|
}.toList
|
||||||
} catch {
|
} catch {
|
||||||
case _: RuntimeException => null
|
case _: RuntimeException => null
|
||||||
|
case _: ReflectiveOperationException => null
|
||||||
}
|
}
|
||||||
case _ => JNull
|
case _ => JNull
|
||||||
}
|
}
|
||||||
|
|
|
@ -95,6 +95,8 @@ case class FakeLeafPlan(child: LogicalPlan)
|
||||||
override def output: Seq[Attribute] = child.output
|
override def output: Seq[Attribute] = child.output
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case class FakeCurryingProduct(x: Expression)(val y: Int)
|
||||||
|
|
||||||
class TreeNodeSuite extends SparkFunSuite with SQLHelper {
|
class TreeNodeSuite extends SparkFunSuite with SQLHelper {
|
||||||
test("top node changed") {
|
test("top node changed") {
|
||||||
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
|
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
|
||||||
|
@ -622,6 +624,21 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
|
||||||
"num-children" -> 0,
|
"num-children" -> 0,
|
||||||
"arg" -> "1"
|
"arg" -> "1"
|
||||||
)))))
|
)))))
|
||||||
|
|
||||||
|
// Convert currying product contains TreeNode to JSON.
|
||||||
|
assertJSON(
|
||||||
|
FakeCurryingProduct(Literal(1))(1),
|
||||||
|
JObject(
|
||||||
|
"product-class" -> classOf[FakeCurryingProduct].getName,
|
||||||
|
"x" -> List(
|
||||||
|
JObject(
|
||||||
|
"class" -> JString(classOf[Literal].getName),
|
||||||
|
"num-children" -> 0,
|
||||||
|
"value" -> "1",
|
||||||
|
"dataType" -> "integer")),
|
||||||
|
"y" -> 1
|
||||||
|
)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("toJSON should not throws java.lang.StackOverflowError") {
|
test("toJSON should not throws java.lang.StackOverflowError") {
|
||||||
|
|
Loading…
Reference in a new issue