[SPARK-35411][SQL][FOLLOWUP] Handle Currying Product while serializing TreeNode to JSON

### What changes were proposed in this pull request?
Handle Currying Product while serializing TreeNode to JSON. While processing [Product](https://github.com/apache/spark/blob/v3.1.2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L820), we may get an assert error for cases like Currying Product because of the mismatch of sizes between field name and field values.
Fall back to using reflection to get all the values for constructor parameters when we meet such cases.

### Why are the changes needed?
Avoid throwing an error while serializing TreeNode to JSON; try to output as much information as possible.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
New UT case added.

Closes #32713 from ivoson/SPARK-35411-followup.

Authored-by: Tengfei Huang <tengfei.h@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Tengfei Huang 2021-05-31 22:15:26 +08:00 committed by Wenchen Fan
parent 14e12c64d3
commit 1603775934
2 changed files with 32 additions and 1 deletions

View file

@ -1030,7 +1030,20 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
case p: Product if shouldConvertToJson(p) =>
try {
val fieldNames = getConstructorParameterNames(p.getClass)
val fieldValues = p.productIterator.toSeq
val fieldValues = {
if (p.productArity == fieldNames.length) {
p.productIterator.toSeq
} else {
val clazz = p.getClass
// Fallback to use reflection if length of product elements do not match
// constructor params.
fieldNames.map { fieldName =>
val field = clazz.getDeclaredField(fieldName)
field.setAccessible(true)
field.get(p)
}
}
}
assert(fieldNames.length == fieldValues.length, s"$simpleClassName fields: " +
fieldNames.mkString(", ") + s", values: " + fieldValues.mkString(", "))
("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
@ -1038,6 +1051,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
}.toList
} catch {
case _: RuntimeException => null
case _: ReflectiveOperationException => null
}
case _ => JNull
}

View file

@ -95,6 +95,8 @@ case class FakeLeafPlan(child: LogicalPlan)
override def output: Seq[Attribute] = child.output
}
// Test fixture: a curried case class. Only the first parameter list (`x`) contributes to
// productIterator/productArity, while getConstructorParameterNames sees both `x` and `y`,
// reproducing the name/value length mismatch this change handles.
case class FakeCurryingProduct(x: Expression)(val y: Int)
class TreeNodeSuite extends SparkFunSuite with SQLHelper {
test("top node changed") {
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
@ -622,6 +624,21 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
"num-children" -> 0,
"arg" -> "1"
)))))
// Convert currying product contains TreeNode to JSON.
assertJSON(
FakeCurryingProduct(Literal(1))(1),
JObject(
"product-class" -> classOf[FakeCurryingProduct].getName,
"x" -> List(
JObject(
"class" -> JString(classOf[Literal].getName),
"num-children" -> 0,
"value" -> "1",
"dataType" -> "integer")),
"y" -> 1
)
)
}
test("toJSON should not throws java.lang.StackOverflowError") {