1717
1818package org .apache .spark .serializer
1919
20- import java .io .{ NotSerializableException , ObjectOutput , ObjectStreamClass , ObjectStreamField }
20+ import java .io ._
2121import java .lang .reflect .{Field , Method }
2222import java .security .AccessController
2323
@@ -62,7 +62,7 @@ private[spark] object SerializationDebugger extends Logging {
6262 *
6363 * Classes that override writeObject are handled by capturing the objects they write to a dummy stream.
6464 */
65- def find (obj : Any ): List [String ] = {
65+ private [serializer] def find (obj : Any ): List [String ] = {
6666 new SerializationDebugger ().visit(obj, List .empty) // start the traversal at obj with an empty path stack
6767 }
6868
@@ -125,6 +125,12 @@ private[spark] object SerializationDebugger extends Logging {
125125 return List .empty
126126 }
127127
128+ /**
129+ * Visit an externalizable object.
130+ * Since writeExternal() can choose to add arbitrary objects at the time of serialization,
131+ * the only way to capture all the objects it will serialize is by using a
132+ * dummy ObjectOutput that collects all the relevant objects for further testing.
133+ */
128134 private def visitExternalizable (o : java.io.Externalizable , stack : List [String ]): List [String ] =
129135 {
130136 val fieldList = new ListObjectOutput
@@ -145,17 +151,50 @@ private[spark] object SerializationDebugger extends Logging {
145151 // An object contains multiple slots in serialization.
146152 // Get the slots and visit fields in all of them.
147153 val (finalObj, desc) = findObjectAndDescriptor(o)
154+
155+ // If the object has been replaced using writeReplace(),
156+ // then call visit() on it again to test its type again.
157+ if (! finalObj.eq(o)) {
158+ return visit(finalObj, s " writeReplace data (class: ${finalObj.getClass.getName}) " :: stack)
159+ }
160+
161+ // Every class is associated with one or more "slots"; each slot corresponds to the class itself
162+ // or to one of its parent classes. These slots are used by the ObjectOutputStream
163+ // serialization code to recursively serialize the fields of an object and
164+ // its parent classes. For example, if there are the following classes.
165+ //
166+ // class ParentClass(parentField: Int)
167+ // class ChildClass(childField: Int) extends ParentClass(1)
168+ //
169+ // Then serializing an object Obj of type ChildClass requires first serializing the fields
170+ // of ParentClass (that is, parentField), and then serializing the fields of ChildClass
171+ // (that is, childField). Correspondingly, there will be two slots related to this object:
172+ //
173+ // 1. ParentClass slot, which will be used to serialize parentField of Obj
174+ // 2. ChildClass slot, which will be used to serialize childField of Obj
175+ //
176+ // The following code uses the description of each slot to find the fields in the
177+ // corresponding object to visit.
178+ //
148179 val slotDescs = desc.getSlotDescs
149180 var i = 0
150181 while (i < slotDescs.length) {
151182 val slotDesc = slotDescs(i)
152183 if (slotDesc.hasWriteObjectMethod) {
153- // TODO: Handle classes that specify writeObject method.
184+ // If the class type corresponding to current slot has writeObject() defined,
185+ // then it's not obvious which fields of the class will be serialized, as writeObject()
186+ // can choose arbitrary fields for serialization. This case is handled separately.
187+ val elem = s " writeObject data (class: ${slotDesc.getName}) "
188+ val childStack = visitSerializableWithWriteObjectMethod(finalObj, elem :: stack)
189+ if (childStack.nonEmpty) {
190+ return childStack
191+ }
154192 } else {
193+ // Visit all the object fields of the class corresponding to the current slot.
155194 val fields : Array [ObjectStreamField ] = slotDesc.getFields
156195 val objFieldValues : Array [Object ] = new Array [Object ](slotDesc.getNumObjFields)
157196 val numPrims = fields.length - objFieldValues.length
158- desc .getObjFieldValues(finalObj, objFieldValues)
197+ slotDesc .getObjFieldValues(finalObj, objFieldValues)
159198
160199 var j = 0
161200 while (j < objFieldValues.length) {
@@ -169,18 +208,54 @@ private[spark] object SerializationDebugger extends Logging {
169208 }
170209 j += 1
171210 }
172-
173211 }
174212 i += 1
175213 }
176214 return List .empty
177215 }
216+
217+ /**
218+ * Visit a serializable object `o` whose class defines writeObject().
219+ * Since writeObject() can choose to add arbitrary objects at the time of serialization,
220+ * the only way to capture all the objects it will serialize is by using a
221+ * dummy ObjectOutputStream that collects all the relevant objects for further testing.
222+ * This is similar to how externalizable objects are visited. Returns the first non-empty stack produced by visiting a captured object, or an empty list.
223+ */
224+ private def visitSerializableWithWriteObjectMethod (
225+ o : Object , stack : List [String ]): List [String ] = {
226+ val innerObjectsCatcher = new ListObjectOutputStream
227+ var notSerializableFound = false
228+ try {
229+ innerObjectsCatcher.writeObject(o)
230+ } catch {
231+ case io : IOException => // also covers NotSerializableException, which is a subclass of IOException
232+ notSerializableFound = true
233+ }
234+
235+ // If something was not serializable, then visit the captured objects to locate the culprit.
236+ // Otherwise, all the captured objects are safely serializable, so no need to visit them.
237+ // As an optimization, just add them to the visited list.
238+ if (notSerializableFound) {
239+ val innerObjects = innerObjectsCatcher.outputArray
240+ var k = 0
241+ while (k < innerObjects.length) {
242+ val childStack = visit(innerObjects(k), stack)
243+ if (childStack.nonEmpty) {
244+ return childStack // stop at the first captured object that reports a problem
245+ }
246+ k += 1
247+ }
248+ } else {
249+ visited ++= innerObjectsCatcher.outputArray
250+ }
251+ return List .empty
252+ }
178253 }
179254
180255 /**
181256 * Find the object to serialize and the associated [[ObjectStreamClass ]]. This method handles
182257 * writeReplace in Serializable. It starts with the object itself, and keeps calling the
183- * writeReplace method until there is no more
258+ * writeReplace method until there is no more.
184259 */
185260 @ tailrec
186261 private def findObjectAndDescriptor (o : Object ): (Object , ObjectStreamClass ) = {
@@ -220,6 +295,31 @@ private[spark] object SerializationDebugger extends Logging {
220295 override def writeByte (i : Int ): Unit = {}
221296 }
222297
298+ /** An output stream that emulates /dev/null: every byte written to it is discarded. */
299+ private class NullOutputStream extends OutputStream {
300+ override def write (b : Int ) { } // intentionally a no-op
301+ }
302+
303+ /**
304+ * A dummy [[ObjectOutputStream ]] that saves the list of objects written to it and returns
305+ * them through `outputArray`. This works by using the [[ObjectOutputStream ]]'s `replaceObject()`
306+ * method which gets called on every object, only if replacing is enabled. So this subclass
307+ * of [[ObjectOutputStream ]] enables replacing, and uses replaceObject to get the objects that
308+ * are being serialized. The serialized bytes are ignored by sending them to a
309+ * [[NullOutputStream ]], which acts like a /dev/null.
310+ */
311+ private class ListObjectOutputStream extends ObjectOutputStream (new NullOutputStream ) {
312+ private val output = new mutable.ArrayBuffer [Any ]
313+ this .enableReplaceObject(true ) // replaceObject() is only invoked once replacing is enabled
314+
315+ def outputArray : Array [Any ] = output.toArray // array copy of the objects captured so far
316+
317+ override def replaceObject (obj : Object ): Object = {
318+ output += obj // record the object being written
319+ obj // return it as-is: nothing is actually replaced
320+ }
321+ }
322+
223323 /** An implicit class that allows us to call private methods of ObjectStreamClass. */
224324 implicit class ObjectStreamClassMethods (val desc : ObjectStreamClass ) extends AnyVal {
225325 def getSlotDescs : Array [ObjectStreamClass ] = {
0 commit comments