[SPARK-22328][CORE] ClosureCleaner should not miss referenced superclass fields

## What changes were proposed in this pull request?

When the given closure uses some fields defined in super class, `ClosureCleaner` can't figure them and don't set it properly. Those fields will be in null values.

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19556 from viirya/SPARK-22328.
This commit is contained in:
Liang-Chi Hsieh 2017-10-26 21:41:45 +01:00 committed by Wenchen Fan
parent 0e9a750a8d
commit 4f8dc6b01e
2 changed files with 133 additions and 12 deletions

View file

@ -91,6 +91,54 @@ private[spark] object ClosureCleaner extends Logging {
(seen - obj.getClass).toList
}
/** Initializes the accessed fields for outer classes and their super classes. */
private def initAccessedFields(
accessedFields: Map[Class[_], Set[String]],
outerClasses: Seq[Class[_]]): Unit = {
for (cls <- outerClasses) {
var currentClass = cls
assert(currentClass != null, "The outer class can't be null.")
while (currentClass != null) {
accessedFields(currentClass) = Set.empty[String]
currentClass = currentClass.getSuperclass()
}
}
}
/** Sets accessed fields for given class in clone object based on given object. */
private def setAccessedFields(
outerClass: Class[_],
clone: AnyRef,
obj: AnyRef,
accessedFields: Map[Class[_], Set[String]]): Unit = {
for (fieldName <- accessedFields(outerClass)) {
val field = outerClass.getDeclaredField(fieldName)
field.setAccessible(true)
val value = field.get(obj)
field.set(clone, value)
}
}
/** Clones a given object and sets accessed fields in cloned object. */
private def cloneAndSetFields(
parent: AnyRef,
obj: AnyRef,
outerClass: Class[_],
accessedFields: Map[Class[_], Set[String]]): AnyRef = {
val clone = instantiateClass(outerClass, parent)
var currentClass = outerClass
assert(currentClass != null, "The outer class can't be null.")
while (currentClass != null) {
setAccessedFields(currentClass, clone, obj, accessedFields)
currentClass = currentClass.getSuperclass()
}
clone
}
/**
* Clean the given closure in place.
*
@ -202,9 +250,8 @@ private[spark] object ClosureCleaner extends Logging {
logDebug(s" + populating accessed fields because this is the starting closure")
// Initialize accessed fields with the outer classes first
// This step is needed to associate the fields to the correct classes later
for (cls <- outerClasses) {
accessedFields(cls) = Set.empty[String]
}
initAccessedFields(accessedFields, outerClasses)
// Populate accessed fields by visiting all fields and methods accessed by this and
// all of its inner closures. If transitive cleaning is enabled, this may recursively
// visits methods that belong to other classes in search of transitively referenced fields.
@ -250,13 +297,8 @@ private[spark] object ClosureCleaner extends Logging {
// required fields from the original object. We need the parent here because the Java
// language specification requires the first constructor parameter of any closure to be
// its enclosing object.
val clone = instantiateClass(cls, parent)
for (fieldName <- accessedFields(cls)) {
val field = cls.getDeclaredField(fieldName)
field.setAccessible(true)
val value = field.get(obj)
field.set(clone, value)
}
val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
// If transitive cleaning is enabled, we recursively clean any enclosing closure using
// the already populated accessed fields map of the starting closure
if (cleanTransitively && isClosure(clone.getClass)) {
@ -395,8 +437,15 @@ private[util] class FieldAccessFinder(
if (!visitedMethods.contains(m)) {
// Keep track of visited methods to avoid potential infinite cycles
visitedMethods += m
ClosureCleaner.getClassReader(cl).accept(
new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
var currentClass = cl
assert(currentClass != null, "The outer class can't be null.")
while (currentClass != null) {
ClosureCleaner.getClassReader(currentClass).accept(
new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
currentClass = currentClass.getSuperclass()
}
}
}
}

View file

@ -119,6 +119,63 @@ class ClosureCleanerSuite extends SparkFunSuite {
test("createNullValue") {
new TestCreateNullValue().run()
}
test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") {
val concreteObject = new TestAbstractClass {
val n2 = 222
val s2 = "bbb"
val d2 = 2.0d
def run(): Seq[(Int, Int, String, String, Double, Double)] = {
withSpark(new SparkContext("local", "test")) { sc =>
val rdd = sc.parallelize(1 to 1)
body(rdd)
}
}
def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ =>
(n1, n2, s1, s2, d1, d2)
}.collect()
}
assert(concreteObject.run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
}
test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") {
val concreteObject = new TestAbstractClass2 {
val n2 = 222
val s2 = "bbb"
val d2 = 2.0d
def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2)
}
withSpark(new SparkContext("local", "test")) { sc =>
val rdd = sc.parallelize(1 to 1).map(concreteObject.getData)
assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
}
}
test("SPARK-22328: multiple outer classes have the same parent class") {
val concreteObject = new TestAbstractClass2 {
val innerObject = new TestAbstractClass2 {
override val n1 = 222
override val s1 = "bbb"
}
val innerObject2 = new TestAbstractClass2 {
override val n1 = 444
val n3 = 333
val s3 = "ccc"
val d3 = 3.0d
def getData: Int => (Int, Int, String, String, Double, Double, Int, String) =
_ => (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1)
}
}
withSpark(new SparkContext("local", "test")) { sc =>
val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData)
assert(rdd.collect() === Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb")))
}
}
}
// A non-serializable class we create in closures to make sure that we aren't
@ -377,3 +434,18 @@ class TestCreateNullValue {
nestedClosure()
}
}
abstract class TestAbstractClass extends Serializable {
val n1 = 111
val s1 = "aaa"
protected val d1 = 1.0d
def run(): Seq[(Int, Int, String, String, Double, Double)]
def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)]
}
abstract class TestAbstractClass2 extends Serializable {
val n1 = 111
val s1 = "aaa"
protected val d1 = 1.0d
}