[SPARK-15927] Eliminate redundant DAGScheduler code.

To try to eliminate redundant code to traverse the RDD dependency graph,
this PR creates a new function getShuffleDependencies that returns
shuffle dependencies that are immediate parents of a given RDD.  This
new function is used by getParentStages and
getAncestorShuffleDependencies.

Author: Kay Ousterhout <kayousterhout@gmail.com>

Closes #13646 from kayousterhout/SPARK-15927.
This commit is contained in:
Kay Ousterhout 2016-06-14 17:26:33 -07:00
parent dae4d5db21
commit 5d50d4f0f9
2 changed files with 74 additions and 39 deletions

View file

@ -378,59 +378,63 @@ class DAGScheduler(
 * the provided firstJobId.
 */
private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
  // Each immediate shuffle dependency marks a stage boundary. Narrow dependencies are
  // traversed inside getShuffleDependencies, so they never create a new stage here.
  // firstJobId is used when a shuffle map stage must be created for a dependency that
  // does not yet have one.
  getShuffleDependencies(rdd).map { shuffleDep =>
    getShuffleMapStage(shuffleDep, firstJobId)
  }.toList
}
/** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet. */
private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
  val ancestors = new Stack[ShuffleDependency[_, _, _]]
  val visited = new HashSet[RDD[_]]
  // We are manually maintaining a stack here to prevent StackOverflowError
  // caused by recursively visiting
  val waitingForVisit = new Stack[RDD[_]]
  waitingForVisit.push(rdd)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      getShuffleDependencies(toVisit).foreach { shuffleDep =>
        if (!shuffleToMapStage.contains(shuffleDep.shuffleId)) {
          ancestors.push(shuffleDep)
          // Keep walking past this unregistered shuffle to find its ancestors too.
          waitingForVisit.push(shuffleDep.rdd)
        } // Otherwise, the dependency and its ancestors have already been registered.
      }
    }
  }
  ancestors
}
/**
 * Returns shuffle dependencies that are immediate parents of the given RDD.
 *
 * This function will not return more distant ancestors. For example, if C has a shuffle
 * dependency on B which has a shuffle dependency on A:
 *
 * A <-- B <-- C
 *
 * calling this function with rdd C will only return the B <-- C dependency.
 *
 * This function is scheduler-visible for the purpose of unit testing.
 */
private[scheduler] def getShuffleDependencies(
    rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
  val parents = new HashSet[ShuffleDependency[_, _, _]]
  val visited = new HashSet[RDD[_]]
  // Manually maintain the traversal stack to avoid StackOverflowError on RDDs with
  // very long narrow-dependency lineage chains.
  val waitingForVisit = new Stack[RDD[_]]
  waitingForVisit.push(rdd)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      toVisit.dependencies.foreach {
        case shuffleDep: ShuffleDependency[_, _, _] =>
          // Stop at the first shuffle: do not traverse past a shuffle boundary.
          parents += shuffleDep
        case dependency =>
          // Narrow dependency: keep searching this branch for a shuffle.
          waitingForVisit.push(dependency.rdd)
      }
    }
  }
  parents
}

View file

@ -2023,6 +2023,37 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
assertDataStructuresEmpty()
}
/**
 * Checks the DAGScheduler's internal logic for traversing a RDD DAG by making sure that
 * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular
 * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s
 * denotes a shuffle dependency):
 *
 * A <------------s---------,
 *                           \
 * B <--s-- C <--s-- D <--n---`-- E
 *
 * Here, the direct shuffle dependency of C is just the shuffle dependency on B. The direct
 * shuffle dependencies of E are the shuffle dependency on A and the shuffle dependency on C.
 */
test("getShuffleDependencies correctly returns only direct shuffle parents") {
  // Two root RDDs with no dependencies of their own.
  val a = new MyRDD(sc, 2, Nil)
  val depOnA = new ShuffleDependency(a, new HashPartitioner(1))
  val b = new MyRDD(sc, 2, Nil)
  val depOnB = new ShuffleDependency(b, new HashPartitioner(1))
  // C shuffles from B, D shuffles from C, and E depends narrowly on D while shuffling from A.
  val c = new MyRDD(sc, 1, List(depOnB))
  val depOnC = new ShuffleDependency(c, new HashPartitioner(1))
  val d = new MyRDD(sc, 1, List(depOnC))
  val e = new MyRDD(sc, 1, List(depOnA, new OneToOneDependency(d)), tracker = mapOutputTracker)
  // Roots report no shuffle parents; every other RDD reports only its immediate shuffle
  // dependencies, never more distant ancestors (e.g. E sees A and C, but not B).
  assert(scheduler.getShuffleDependencies(a) === Set())
  assert(scheduler.getShuffleDependencies(b) === Set())
  assert(scheduler.getShuffleDependencies(c) === Set(depOnB))
  assert(scheduler.getShuffleDependencies(d) === Set(depOnC))
  assert(scheduler.getShuffleDependencies(e) === Set(depOnA, depOnC))
}
/**
 * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
 * Note that this checks only the host and not the executor ID.