[SPARK-22328][CORE] ClosureCleaner should not miss referenced superclass fields
## What changes were proposed in this pull request? When the given closure uses fields defined in a superclass, `ClosureCleaner` fails to detect them and does not set them properly, so those fields end up with null values. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #19556 from viirya/SPARK-22328.
This commit is contained in:
parent
0e9a750a8d
commit
4f8dc6b01e
|
@ -91,6 +91,54 @@ private[spark] object ClosureCleaner extends Logging {
|
|||
(seen - obj.getClass).toList
|
||||
}
|
||||
|
||||
/**
 * Seeds `accessedFields` with an empty field-name set for every outer class and for each
 * class on its superclass chain.
 */
private def initAccessedFields(
    accessedFields: Map[Class[_], Set[String]],
    outerClasses: Seq[Class[_]]): Unit = {
  outerClasses.foreach { outer =>
    assert(outer != null, "The outer class can't be null.")
    // Walk from the outer class up through its ancestors until getSuperclass() yields null.
    Iterator.iterate[Class[_]](outer)(_.getSuperclass())
      .takeWhile(_ != null)
      .foreach { klass => accessedFields(klass) = Set.empty[String] }
  }
}
|
||||
|
||||
/**
 * Copies, from `obj` into `clone`, every field that `accessedFields` lists for `outerClass`.
 */
private def setAccessedFields(
    outerClass: Class[_],
    clone: AnyRef,
    obj: AnyRef,
    accessedFields: Map[Class[_], Set[String]]): Unit = {
  accessedFields(outerClass).foreach { name =>
    // The field is resolved on outerClass itself, so callers handle superclasses separately.
    val declared = outerClass.getDeclaredField(name)
    declared.setAccessible(true)
    declared.set(clone, declared.get(obj))
  }
}
|
||||
|
||||
/**
 * Builds a new `outerClass` instance through `instantiateClass`, then copies the accessed
 * fields of `obj` into it for `outerClass` and each of its superclasses.
 */
private def cloneAndSetFields(
    parent: AnyRef,
    obj: AnyRef,
    outerClass: Class[_],
    accessedFields: Map[Class[_], Set[String]]): AnyRef = {
  val clone = instantiateClass(outerClass, parent)
  assert(outerClass != null, "The outer class can't be null.")

  // Visit the class and then each ancestor in turn, stopping once getSuperclass() is null.
  Iterator.iterate[Class[_]](outerClass)(_.getSuperclass())
    .takeWhile(_ != null)
    .foreach { klass => setAccessedFields(klass, clone, obj, accessedFields) }

  clone
}
|
||||
|
||||
/**
|
||||
* Clean the given closure in place.
|
||||
*
|
||||
|
@ -202,9 +250,8 @@ private[spark] object ClosureCleaner extends Logging {
|
|||
logDebug(s" + populating accessed fields because this is the starting closure")
|
||||
// Initialize accessed fields with the outer classes first
|
||||
// This step is needed to associate the fields to the correct classes later
|
||||
for (cls <- outerClasses) {
|
||||
accessedFields(cls) = Set.empty[String]
|
||||
}
|
||||
initAccessedFields(accessedFields, outerClasses)
|
||||
|
||||
// Populate accessed fields by visiting all fields and methods accessed by this and
|
||||
// all of its inner closures. If transitive cleaning is enabled, this may recursively
|
||||
// visits methods that belong to other classes in search of transitively referenced fields.
|
||||
|
@ -250,13 +297,8 @@ private[spark] object ClosureCleaner extends Logging {
|
|||
// required fields from the original object. We need the parent here because the Java
|
||||
// language specification requires the first constructor parameter of any closure to be
|
||||
// its enclosing object.
|
||||
val clone = instantiateClass(cls, parent)
|
||||
for (fieldName <- accessedFields(cls)) {
|
||||
val field = cls.getDeclaredField(fieldName)
|
||||
field.setAccessible(true)
|
||||
val value = field.get(obj)
|
||||
field.set(clone, value)
|
||||
}
|
||||
val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
|
||||
|
||||
// If transitive cleaning is enabled, we recursively clean any enclosing closure using
|
||||
// the already populated accessed fields map of the starting closure
|
||||
if (cleanTransitively && isClosure(clone.getClass)) {
|
||||
|
@ -395,8 +437,15 @@ private[util] class FieldAccessFinder(
|
|||
if (!visitedMethods.contains(m)) {
|
||||
// Keep track of visited methods to avoid potential infinite cycles
|
||||
visitedMethods += m
|
||||
ClosureCleaner.getClassReader(cl).accept(
|
||||
new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
|
||||
|
||||
var currentClass = cl
|
||||
assert(currentClass != null, "The outer class can't be null.")
|
||||
|
||||
while (currentClass != null) {
|
||||
ClosureCleaner.getClassReader(currentClass).accept(
|
||||
new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
|
||||
currentClass = currentClass.getSuperclass()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -119,6 +119,63 @@ class ClosureCleanerSuite extends SparkFunSuite {
|
|||
test("createNullValue") {
|
||||
new TestCreateNullValue().run()
|
||||
}
|
||||
|
||||
test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") {
  // n2, s2 and d2 are declared on the anonymous subclass; n1, s1 and d1 come from
  // TestAbstractClass, so the mapped closure reads fields from both levels.
  val instance = new TestAbstractClass {
    val n2 = 222
    val s2 = "bbb"
    val d2 = 2.0d

    def run(): Seq[(Int, Int, String, String, Double, Double)] =
      withSpark(new SparkContext("local", "test")) { sc =>
        body(sc.parallelize(1 to 1))
      }

    def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = {
      val tuples = rdd.map { _ => (n1, n2, s1, s2, d1, d2) }
      tuples.collect()
    }
  }
  val expected = Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))
  assert(instance.run() === expected)
}
|
||||
|
||||
test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") {
  // Here the closure is a function value returned by the anonymous subclass, and it reads
  // both the subclass fields and the ones declared in TestAbstractClass2.
  val instance = new TestAbstractClass2 {
    val n2 = 222
    val s2 = "bbb"
    val d2 = 2.0d
    def getData: Int => (Int, Int, String, String, Double, Double) = { _ =>
      (n1, n2, s1, s2, d1, d2)
    }
  }
  withSpark(new SparkContext("local", "test")) { sc =>
    val mapped = sc.parallelize(1 to 1).map(instance.getData)
    val expected = Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))
    assert(mapped.collect() === expected)
  }
}
|
||||
|
||||
test("SPARK-22328: multiple outer classes have the same parent class") {
  // Three anonymous subclasses of TestAbstractClass2 are involved: the enclosing object and
  // its two members. The closure comes from innerObject2 and also reads innerObject's fields.
  val enclosing = new TestAbstractClass2 {

    val innerObject = new TestAbstractClass2 {
      override val n1 = 222
      override val s1 = "bbb"
    }

    val innerObject2 = new TestAbstractClass2 {
      override val n1 = 444
      val n3 = 333
      val s3 = "ccc"
      val d3 = 3.0d

      def getData: Int => (Int, Int, String, String, Double, Double, Int, String) = { _ =>
        (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1)
      }
    }
  }
  withSpark(new SparkContext("local", "test")) { sc =>
    val mapped = sc.parallelize(1 to 1).map(enclosing.innerObject2.getData)
    val expected = Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb"))
    assert(mapped.collect() === expected)
  }
}
|
||||
}
|
||||
|
||||
// A non-serializable class we create in closures to make sure that we aren't
|
||||
|
@ -377,3 +434,18 @@ class TestCreateNullValue {
|
|||
nestedClosure()
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Serializable base fixture that declares three concrete fields and leaves `run` and `body`
 * for subclasses to implement.
 */
abstract class TestAbstractClass extends Serializable {
  val n1: Int = 111
  val s1: String = "aaa"
  protected val d1: Double = 1.0d

  /** Implemented by subclasses; returns a sequence of six-element tuples. */
  def run(): Seq[(Int, Int, String, String, Double, Double)]

  /** Implemented by subclasses; turns the given RDD into a sequence of six-element tuples. */
  def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)]
}
|
||||
|
||||
/**
 * Serializable base fixture that only declares three concrete fields; it has no abstract
 * members.
 */
abstract class TestAbstractClass2 extends Serializable {
  val n1: Int = 111
  val s1: String = "aaa"
  protected val d1: Double = 1.0d
}
|
||||
|
|
Loading…
Reference in a new issue