[SPARK-7180] [SPARK-8090] [SPARK-8091] Fix a number of SerializationDebugger bugs and limitations

This PR solves three SerializationDebugger issues.
* SPARK-7180 - SerializationDebugger fails with ArrayIndexOutOfBoundsException
* SPARK-8090 - SerializationDebugger does not handle classes with writeReplace correctly
* SPARK-8091 - SerializationDebugger does not handle classes with writeObject method

The solutions for each are explained as follows:
* SPARK-7180 - The wrong slot desc was used for getting the value of the fields in the object being tested.
* SPARK-8090 - Test the type of the replaced object.
* SPARK-8091 - Use a dummy ObjectOutputStream to collect all the objects written by the writeObject() method, and then test those objects as usual.

I also added more tests to the test suite to increase code coverage. For example, I added tests for cases where there are no serializability issues.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #6625 from tdas/SPARK-7180 and squashes the following commits:

c7cb046 [Tathagata Das] Addressed comments on docs
ae212c8 [Tathagata Das] Improved docs
304c97b [Tathagata Das] Fixed build error
26b5179 [Tathagata Das] more tests.....92% line coverage
7e2fdcf [Tathagata Das] Added more tests
d1967fb [Tathagata Das] Added comments.
da75d34 [Tathagata Das] Removed unnecessary lines.
50a608d [Tathagata Das] Fixed bugs and added support for writeObject
This commit is contained in:
Tathagata Das 2015-06-19 10:52:30 -07:00 committed by Andrew Or
parent 3415fb978b
commit 4b2c793a27
3 changed files with 223 additions and 12 deletions

View file

@ -17,7 +17,7 @@
package org.apache.spark.serializer package org.apache.spark.serializer
import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField} import java.io._
import java.lang.reflect.{Field, Method} import java.lang.reflect.{Field, Method}
import java.security.AccessController import java.security.AccessController
@ -62,7 +62,7 @@ private[spark] object SerializationDebugger extends Logging {
* *
* It does not yet handle writeObject override, but that shouldn't be too hard to do either. * It does not yet handle writeObject override, but that shouldn't be too hard to do either.
*/ */
def find(obj: Any): List[String] = { private[serializer] def find(obj: Any): List[String] = {
new SerializationDebugger().visit(obj, List.empty) new SerializationDebugger().visit(obj, List.empty)
} }
@ -125,6 +125,12 @@ private[spark] object SerializationDebugger extends Logging {
return List.empty return List.empty
} }
/**
* Visit an externalizable object.
* Since writeExternal() can choose to add arbitrary objects at the time of serialization,
* the only way to capture all the objects it will serialize is by using a
* dummy ObjectOutput that collects all the relevant objects for further testing.
*/
private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] = private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] =
{ {
val fieldList = new ListObjectOutput val fieldList = new ListObjectOutput
@ -145,17 +151,50 @@ private[spark] object SerializationDebugger extends Logging {
// An object contains multiple slots in serialization. // An object contains multiple slots in serialization.
// Get the slots and visit fields in all of them. // Get the slots and visit fields in all of them.
val (finalObj, desc) = findObjectAndDescriptor(o) val (finalObj, desc) = findObjectAndDescriptor(o)
// If the object has been replaced using writeReplace(),
// then call visit() on it again to test its type again.
if (!finalObj.eq(o)) {
return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack)
}
// Every class is associated with one or more "slots", each slot refers to the parent
// classes of this class. These slots are used by the ObjectOutputStream
// serialization code to recursively serialize the fields of an object and
// its parent classes. For example, if there are the following classes.
//
// class ParentClass(parentField: Int)
// class ChildClass(childField: Int) extends ParentClass(1)
//
// Then serializing an object Obj of type ChildClass requires first serializing the fields
// of ParentClass (that is, parentField), and then serializing the fields of ChildClass
// (that is, childField). Correspondingly, there will be two slots related to this object:
//
// 1. ParentClass slot, which will be used to serialize parentField of Obj
// 2. ChildClass slot, which will be used to serialize childField of Obj
//
// The following code uses the description of each slot to find the fields in the
// corresponding object to visit.
//
val slotDescs = desc.getSlotDescs val slotDescs = desc.getSlotDescs
var i = 0 var i = 0
while (i < slotDescs.length) { while (i < slotDescs.length) {
val slotDesc = slotDescs(i) val slotDesc = slotDescs(i)
if (slotDesc.hasWriteObjectMethod) { if (slotDesc.hasWriteObjectMethod) {
// TODO: Handle classes that specify writeObject method. // If the class type corresponding to current slot has writeObject() defined,
// then it's not obvious which fields of the class will be serialized as the writeObject()
// can choose arbitrary fields for serialization. This case is handled separately.
val elem = s"writeObject data (class: ${slotDesc.getName})"
val childStack = visitSerializableWithWriteObjectMethod(finalObj, elem :: stack)
if (childStack.nonEmpty) {
return childStack
}
} else { } else {
// Visit all the fields objects of the class corresponding to the current slot.
val fields: Array[ObjectStreamField] = slotDesc.getFields val fields: Array[ObjectStreamField] = slotDesc.getFields
val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields) val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields)
val numPrims = fields.length - objFieldValues.length val numPrims = fields.length - objFieldValues.length
desc.getObjFieldValues(finalObj, objFieldValues) slotDesc.getObjFieldValues(finalObj, objFieldValues)
var j = 0 var j = 0
while (j < objFieldValues.length) { while (j < objFieldValues.length) {
@ -169,18 +208,54 @@ private[spark] object SerializationDebugger extends Logging {
} }
j += 1 j += 1
} }
} }
i += 1 i += 1
} }
return List.empty return List.empty
} }
/**
 * Visit a serializable object whose class defines its own writeObject() method.
 *
 * A custom writeObject() may write arbitrary objects, so the objects cannot be discovered
 * by inspecting fields. Instead the object is written to a [[ListObjectOutputStream]],
 * which records the objects handed to it, in the same spirit as the dummy ObjectOutput
 * used for externalizable objects.
 *
 * @param o     the object to test
 * @param stack the description path leading to `o`, passed on to `visit` for every
 *              captured object
 * @return the first non-empty result returned by `visit` for a captured object, or an
 *         empty list if the write succeeded or no captured object produced a result
 */
private def visitSerializableWithWriteObjectMethod(
    o: Object, stack: List[String]): List[String] = {
  val objectCollector = new ListObjectOutputStream

  // Attempt the write against the collecting stream; an IOException is taken as the
  // signal that something written by `o` could not be serialized.
  val serializationFailed =
    try {
      objectCollector.writeObject(o)
      false
    } catch {
      case _: IOException => true
    }

  if (serializationFailed) {
    // Examine the captured objects one by one, in capture order. The iterator is lazy,
    // so visiting stops as soon as one of them yields a non-empty result.
    objectCollector.outputArray.iterator
      .map(captured => visit(captured, stack))
      .find(_.nonEmpty)
      .getOrElse(List.empty)
  } else {
    // The write succeeded, so there is nothing to report for the captured objects;
    // as an optimization they are simply added to `visited` rather than visited.
    visited ++= objectCollector.outputArray
    List.empty
  }
}
} }
/** /**
* Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles
* writeReplace in Serializable. It starts with the object itself, and keeps calling the * writeReplace in Serializable. It starts with the object itself, and keeps calling the
* writeReplace method until there is no more * writeReplace method until there is no more.
*/ */
@tailrec @tailrec
private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = { private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = {
@ -220,6 +295,31 @@ private[spark] object SerializationDebugger extends Logging {
override def writeByte(i: Int): Unit = {} override def writeByte(i: Int): Unit = {}
} }
/**
 * An output stream that emulates /dev/null: every byte written to it is discarded.
 */
private class NullOutputStream extends OutputStream {
  // Declared with an explicit `: Unit =` result type; the former procedure syntax
  // (`def write(b: Int) { }`) is deprecated in Scala 2.13 and rejected by Scala 3.
  override def write(b: Int): Unit = {}
}
/**
 * A dummy [[ObjectOutputStream]] that remembers the objects passed to its
 * `replaceObject()` hook and exposes them through `outputArray`.
 *
 * Object replacement is switched on in the constructor so that the stream invokes
 * `replaceObject()` while serializing; the override records each object it receives and
 * returns it unchanged, so nothing is actually replaced. The serialized bytes themselves
 * are of no interest and are sent to a [[NullOutputStream]], which acts like /dev/null.
 */
private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) {
  // Objects seen by replaceObject(), in the order they were received.
  private val capturedObjects = mutable.ArrayBuffer.empty[Any]
  enableReplaceObject(true)

  /** A snapshot of all objects captured so far. */
  def outputArray: Array[Any] = capturedObjects.toArray

  /** Records `obj` and hands it back untouched. */
  override def replaceObject(obj: Object): Object = {
    capturedObjects.append(obj)
    obj
  }
}
/** An implicit class that allows us to call private methods of ObjectStreamClass. */ /** An implicit class that allows us to call private methods of ObjectStreamClass. */
implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal { implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal {
def getSlotDescs: Array[ObjectStreamClass] = { def getSlotDescs: Array[ObjectStreamClass] = {

View file

@ -17,7 +17,7 @@
package org.apache.spark.serializer package org.apache.spark.serializer
import java.io.{ObjectOutput, ObjectInput} import java.io._
import org.scalatest.BeforeAndAfterEach import org.scalatest.BeforeAndAfterEach
@ -98,7 +98,7 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach {
} }
test("externalizable class writing out not serializable object") { test("externalizable class writing out not serializable object") {
val s = find(new ExternalizableClass) val s = find(new ExternalizableClass(new SerializableClass2(new NotSerializable)))
assert(s.size === 5) assert(s.size === 5)
assert(s(0).contains("NotSerializable")) assert(s(0).contains("NotSerializable"))
assert(s(1).contains("objectField")) assert(s(1).contains("objectField"))
@ -106,6 +106,93 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach {
assert(s(3).contains("writeExternal")) assert(s(3).contains("writeExternal"))
assert(s(4).contains("ExternalizableClass")) assert(s(4).contains("ExternalizableClass"))
} }
test("externalizable class writing out serializable objects") {
assert(find(new ExternalizableClass(new SerializableClass1)).isEmpty)
}
test("object containing writeReplace() which returns not serializable object") {
val s = find(new SerializableClassWithWriteReplace(new NotSerializable))
assert(s.size === 3)
assert(s(0).contains("NotSerializable"))
assert(s(1).contains("writeReplace"))
assert(s(2).contains("SerializableClassWithWriteReplace"))
}
test("object containing writeReplace() which returns serializable object") {
assert(find(new SerializableClassWithWriteReplace(new SerializableClass1)).isEmpty)
}
test("object containing writeObject() and not serializable field") {
val s = find(new SerializableClassWithWriteObject(new NotSerializable))
assert(s.size === 3)
assert(s(0).contains("NotSerializable"))
assert(s(1).contains("writeObject data"))
assert(s(2).contains("SerializableClassWithWriteObject"))
}
test("object containing writeObject() and serializable field") {
assert(find(new SerializableClassWithWriteObject(new SerializableClass1)).isEmpty)
}
test("object of serializable subclass with more fields than superclass (SPARK-7180)") {
// This should not throw ArrayIndexOutOfBoundsException
find(new SerializableSubclass(new SerializableClass1))
}
test("crazy nested objects") {
def findAndAssert(shouldSerialize: Boolean, obj: Any): Unit = {
val s = find(obj)
if (shouldSerialize) {
assert(s.isEmpty)
} else {
assert(s.nonEmpty)
assert(s.head.contains("NotSerializable"))
}
}
findAndAssert(false,
new SerializableClassWithWriteReplace(new ExternalizableClass(new SerializableSubclass(
new SerializableArray(
Array(new SerializableClass1, new SerializableClass2(new NotSerializable))
)
)))
)
findAndAssert(true,
new SerializableClassWithWriteReplace(new ExternalizableClass(new SerializableSubclass(
new SerializableArray(
Array(new SerializableClass1, new SerializableClass2(new SerializableClass1))
)
)))
)
}
test("improveException") {
val e = SerializationDebugger.improveException(
new SerializableClass2(new NotSerializable), new NotSerializableException("someClass"))
assert(e.getMessage.contains("someClass")) // original exception message should be present
assert(e.getMessage.contains("SerializableClass2")) // found debug trace should be present
}
test("improveException with error in debugger") {
// Object that throws exception in the SerializationDebugger
val o = new SerializableClass1 {
private def writeReplace(): Object = {
throw new Exception()
}
}
withClue("requirement: SerializationDebugger should fail trying debug this object") {
intercept[Exception] {
SerializationDebugger.find(o)
}
}
val originalException = new NotSerializableException("someClass")
// verify that the original exception is returned on failure
assert(SerializationDebugger.improveException(o, originalException).eq(originalException))
}
} }
@ -118,10 +205,34 @@ class SerializableClass2(val objectField: Object) extends Serializable
class SerializableArray(val arrayField: Array[Object]) extends Serializable class SerializableArray(val arrayField: Array[Object]) extends Serializable
class ExternalizableClass extends java.io.Externalizable { class SerializableSubclass(val objectField: Object) extends SerializableClass1
/**
 * Test fixture: a serializable class that declares its own writeObject() hook.
 * The hook only delegates to the default field serialization.
 *
 * @param objectField an arbitrary object held as a field of this class
 */
class SerializableClassWithWriteObject(val objectField: Object) extends Serializable {

  val serializableObjectField = new SerializableClass1

  /** Custom serialization hook; writes the fields using the default mechanism. */
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = out.defaultWriteObject()
}
/**
 * Test fixture: a serializable class whose writeReplace() hook returns the object it was
 * constructed with.
 *
 * @param replacementFieldObject the object returned by writeReplace()
 */
class SerializableClassWithWriteReplace(@transient replacementFieldObject: Object)
  extends Serializable {

  /** Serialization hook: nominates the constructor argument as the replacement object. */
  private def writeReplace(): Object = replacementFieldObject
}
class ExternalizableClass(objectField: Object) extends java.io.Externalizable {
val serializableObjectField = new SerializableClass1
override def writeExternal(out: ObjectOutput): Unit = { override def writeExternal(out: ObjectOutput): Unit = {
out.writeInt(1) out.writeInt(1)
out.writeObject(new SerializableClass2(new NotSerializable)) out.writeObject(serializableObjectField)
out.writeObject(objectField)
} }
override def readExternal(in: ObjectInput): Unit = {} override def readExternal(in: ObjectInput): Unit = {}

View file

@ -549,8 +549,8 @@ class StreamingContext private[streaming] (
case e: NotSerializableException => case e: NotSerializableException =>
throw new NotSerializableException( throw new NotSerializableException(
"DStream checkpointing has been enabled but the DStreams with their functions " + "DStream checkpointing has been enabled but the DStreams with their functions " +
"are not serializable\nSerialization stack:\n" + "are not serializable\n" +
SerializationDebugger.find(checkpoint).map("\t- " + _).mkString("\n") SerializationDebugger.improveException(checkpoint, e).getMessage()
) )
} }
} }