[SPARK-4454] Properly synchronize accesses to DAGScheduler cacheLocs map

This patch addresses a race condition in DAGScheduler by properly synchronizing accesses to its `cacheLocs` map.

This map is accessed by the `getCacheLocs()` and `clearCacheLocs()` methods, which can be called from separate threads, since DAGScheduler's `getPreferredLocs()` method is called by SparkContext and indirectly calls `getCacheLocs()`.  If the DAGScheduler event processing thread clears this map while a user thread is submitting a job and computing preferred locations, the user thread can throw "NoSuchElementException: key not found" errors.
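The race is a classic check-then-act bug on a shared mutable map. A minimal sketch (hypothetical names and a simplified value type, not the actual DAGScheduler code) of the failure mode described above:

```scala
import scala.collection.mutable.HashMap

object RaceSketch {
  // Shared mutable state, analogous to DAGScheduler's cacheLocs map.
  val cacheLocs = new HashMap[Int, String]

  // Unsynchronized check-then-act: between the contains() check and the
  // apply() read, another thread may run clearCacheLocs(), so the final
  // read can throw "NoSuchElementException: key not found".
  def lookup(id: Int): String = {
    if (!cacheLocs.contains(id)) {
      cacheLocs(id) = s"locs-$id"
    }
    cacheLocs(id)
  }

  // Called from the event processing thread in the real scheduler.
  def clearCacheLocs(): Unit = cacheLocs.clear()
}
```

Single-threaded the code behaves correctly, which is why the bug only surfaces when a user thread computing preferred locations overlaps with the event loop clearing the map.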

Most accesses to DAGScheduler's internal state do not need synchronization because that state is only accessed from the event processing loop's thread.  An alternative approach to fixing this bug would be to refactor this code so that SparkContext sends the DAGScheduler a message in order to get the list of preferred locations.  However, this would involve more extensive changes to this code and would be significantly harder to backport to maintenance branches since some of the related code has undergone significant refactoring (e.g. the introduction of EventLoop).  Since `cacheLocs` is the only state that's accessed in this way, adding simple synchronization seems like a better short-term fix.
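The guarded-access pattern the patch adopts can be sketched as follows (simplified types and hypothetical values, not the actual Spark code): every access to the shared map, including the read-modify-write in the getter, runs while holding the map's own monitor.

```scala
import scala.collection.mutable.HashMap

object SynchronizedCacheLocs {
  private val cacheLocs = new HashMap[Int, Seq[String]]

  def getCacheLocs(id: Int): Seq[String] = cacheLocs.synchronized {
    // getOrElseUpdate makes the check-compute-insert sequence one critical
    // section, so a concurrent clearCacheLocs() cannot interleave between
    // the existence check and the read.
    cacheLocs.getOrElseUpdate(id, Seq(s"host-$id"))
  }

  def clearCacheLocs(): Unit = cacheLocs.synchronized {
    cacheLocs.clear()
  }
}
```

Because both methods synchronize on the same object (the map itself), a clear can no longer race with a lookup; the worst case is that a cleared entry is simply recomputed on the next call.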

See #3345 for additional context.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #4660 from JoshRosen/SPARK-4454 and squashes the following commits:

12d64ba [Josh Rosen] Properly synchronize accesses to DAGScheduler cacheLocs map.
Josh Rosen 2015-02-17 17:39:58 -08:00 committed by Patrick Wendell
parent ae6cfb3acd
commit d46d6246d2

@@ -98,7 +98,13 @@ class DAGScheduler(
   private[scheduler] val activeJobs = new HashSet[ActiveJob]
 
-  // Contains the locations that each RDD's partitions are cached on
+  /**
+   * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids
+   * and its values are arrays indexed by partition numbers. Each array value is the set of
+   * locations where that RDD partition is cached.
+   *
+   * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454).
+   */
   private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
 
   // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
@@ -183,18 +189,17 @@ class DAGScheduler(
     eventProcessLoop.post(TaskSetFailed(taskSet, reason))
   }
 
-  private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
-    if (!cacheLocs.contains(rdd.id)) {
+  private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = cacheLocs.synchronized {
+    cacheLocs.getOrElseUpdate(rdd.id, {
       val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
       val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
-      cacheLocs(rdd.id) = blockIds.map { id =>
+      blockIds.map { id =>
         locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
       }
-    }
-    cacheLocs(rdd.id)
+    })
   }
 
-  private def clearCacheLocs() {
+  private def clearCacheLocs(): Unit = cacheLocs.synchronized {
     cacheLocs.clear()
   }
@@ -1276,17 +1281,26 @@ class DAGScheduler(
   }
 
   /**
-   * Synchronized method that might be called from other threads.
+   * Gets the locality information associated with a partition of a particular RDD.
+   *
+   * This method is thread-safe and is called from both DAGScheduler and SparkContext.
+   *
    * @param rdd whose partitions are to be looked at
    * @param partition to lookup locality information for
    * @return list of machines that are preferred by the partition
    */
   private[spark]
-  def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized {
+  def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
     getPreferredLocsInternal(rdd, partition, new HashSet)
   }
 
-  /** Recursive implementation for getPreferredLocs. */
+  /**
+   * Recursive implementation for getPreferredLocs.
+   *
+   * This method is thread-safe because it only accesses DAGScheduler state through thread-safe
+   * methods (getCacheLocs()); please be careful when modifying this method, because any new
+   * DAGScheduler state accessed by it may require additional synchronization.
+   */
   private def getPreferredLocsInternal(
       rdd: RDD[_],
       partition: Int,