[SPARK-3137][CORE] Replace the global TorrentBroadcast lock with fine grained KeyLock

### What changes were proposed in this pull request?

This PR adds a new locking mechanism, `KeyLock`, that locks on a given key. It also uses this new lock in `TorrentBroadcast` so that tasks fetching different broadcast values no longer block each other.

### Why are the changes needed?

`TorrentBroadcast.readObject` uses a global lock so only one task can be fetching the blocks at the same time. This is not optimal if we are running multiple stages concurrently because they should be able to independently fetch their own blocks.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #25612 from zsxwing/SPARK-3137.

Authored-by: Shixiong Zhu <zsxwing@gmail.com>
Signed-off-by: Shixiong Zhu <zsxwing@gmail.com>
This commit is contained in:
Shixiong Zhu 2019-09-03 14:09:07 -07:00
parent 5ea134c354
commit 89800931aa
No known key found for this signature in database
GPG key ID: 34400CF75FADFD94
4 changed files with 207 additions and 9 deletions

View file

@ -17,6 +17,7 @@
package org.apache.spark.broadcast
import java.util.Collections
import java.util.concurrent.atomic.AtomicLong
import scala.reflect.ClassTag
@ -55,9 +56,11 @@ private[spark] class BroadcastManager(
private val nextBroadcastId = new AtomicLong(0)
private[broadcast] val cachedValues = {
new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
}
private[broadcast] val cachedValues =
Collections.synchronizedMap(
new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
.asInstanceOf[java.util.Map[Any, Any]]
)
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
val bid = nextBroadcastId.getAndIncrement()

View file

@ -31,7 +31,7 @@ import org.apache.spark.internal.{config, Logging}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.util.Utils
import org.apache.spark.util.{KeyLock, Utils}
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
/**
@ -167,7 +167,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
bm.getLocalBytes(pieceId) match {
case Some(block) =>
blocks(pid) = block
releaseLock(pieceId)
releaseBlockManagerLock(pieceId)
case None =>
bm.getRemoteBytes(pieceId) match {
case Some(b) =>
@ -215,8 +215,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}
private def readBroadcastBlock(): T = Utils.tryOrIOException {
val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
broadcastCache.synchronized {
TorrentBroadcast.torrentBroadcastLock.withLock(broadcastId) {
// As we only lock based on `broadcastId`, whenever using `broadcastCache`, we should only
// touch `broadcastId`.
val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
setConf(SparkEnv.get.conf)
@ -225,7 +227,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
case Some(blockResult) =>
if (blockResult.data.hasNext) {
val x = blockResult.data.next().asInstanceOf[T]
releaseLock(broadcastId)
releaseBlockManagerLock(broadcastId)
if (x != null) {
broadcastCache.put(broadcastId, x)
@ -270,7 +272,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
* If running in a task, register the given block's locks for release upon task completion.
* Otherwise, if not running in a task then immediately release the lock.
*/
private def releaseLock(blockId: BlockId): Unit = {
private def releaseBlockManagerLock(blockId: BlockId): Unit = {
val blockManager = SparkEnv.get.blockManager
Option(TaskContext.get()) match {
case Some(taskContext) =>
@ -290,6 +292,12 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
private object TorrentBroadcast extends Logging {
/**
* A [[KeyLock]] whose key is [[BroadcastBlockId]] to ensure there is only one thread fetching
* the same [[TorrentBroadcast]] block.
*/
private val torrentBroadcastLock = new KeyLock[BroadcastBlockId]
def blockifyObject[T: ClassTag](
obj: T,
blockSize: Int,

View file

@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.util.concurrent.ConcurrentHashMap
/**
* A special locking mechanism to provide locking with a given key. By providing the same key
* (identity is tested using the `equals` method), we ensure there is only one `func` running at
* the same time.
*
* Internally, `lockMap` holds one monitor object per key that is currently locked. The entry is
* added when the lock is acquired and removed when it is released, so keys that are not held
* have no entry in the map.
*
* The lock is not reentrant: no owner is recorded, so calling `withLock` with an equal key from
* inside `func` on the same thread waits for a release that can never happen.
*
* @tparam K the type of key to identify a lock. This type must implement `equals` and `hashCode`
* correctly as it will be the key type of an internal Map.
*/
private[spark] class KeyLock[K] {
// One entry per key that is currently locked; the value is the monitor waiters block on.
private val lockMap = new ConcurrentHashMap[K, AnyRef]()
/**
* Block the calling thread until it has installed its own entry for `key` in `lockMap`.
*
* @param key the key to lock; `withLock` rejects `null` before calling this method
*/
private def acquireLock(key: K): Unit = {
while (true) {
// Try to become the holder by installing a fresh monitor object for this key.
val lock = lockMap.putIfAbsent(key, new Object)
// `null` means there was no previous entry, so the object we just put is now the lock.
if (lock == null) return
// Another thread holds the key: wait on its monitor until that entry is gone.
lock.synchronized {
// Re-check under the monitor. The holder calls `notifyAll` while holding the same monitor,
// so the wake-up cannot slip in between this check and `wait()`. The loop also covers
// wake-ups that happen while the entry is still the same object.
while (lockMap.get(key) eq lock) {
lock.wait()
}
}
// The previous holder's entry is gone (or replaced by another thread's); retry from the top.
}
}
/**
* Release the lock for `key` and wake up every thread waiting on it.
*
* Must only be called by the thread that currently holds `key`: if there is no entry,
* `remove` returns `null` and the `synchronized` call below fails.
*
* @param key the key whose lock is being released
*/
private def releaseLock(key: K): Unit = {
// Remove the entry first so that woken waiters see the key as free when they re-check.
val lock = lockMap.remove(key)
lock.synchronized {
lock.notifyAll()
}
}
/**
* Run `func` under a lock identified by the given key. Multiple calls with the same key
* (identity is tested using the `equals` method) will be locked properly to ensure there is only
* one `func` running at the same time.
*
* The lock is released in a `finally` block, so it is freed even when `func` throws.
*
* @param key the key identifying the lock; must not be `null`
* @param func the code to run while the lock is held
* @tparam T the result type of `func`
* @return the value produced by `func`
* @throws NullPointerException if `key` is `null`
*/
def withLock[T](key: K)(func: => T): T = {
if (key == null) {
throw new NullPointerException("key must not be null")
}
// Acquire outside the `try`: if acquisition itself fails there is nothing to release.
acquireLock(key)
try {
func
} finally {
releaseLock(key)
}
}
}

View file

@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.duration._
import org.scalatest.concurrent.{ThreadSignaler, TimeLimits}
import org.apache.spark.SparkFunSuite
/**
 * Tests for [[KeyLock]]: callers using equal keys must be mutually exclusive, while callers
 * using different keys must not block each other.
 */
class KeyLockSuite extends SparkFunSuite with TimeLimits {

  // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
  private implicit val defaultSignaler = ThreadSignaler

  // Upper bound used wherever a test would otherwise wait "forever".
  private val foreverMs = 60 * 1000L

  test("The same key should wait when its lock is held") {
    val keyLock = new KeyLock[Object]
    val numThreads = 10
    // Create different objects that are equal
    val keys = List.fill(numThreads)(List(1))
    require(keys.tail.forall(_ ne keys.head) && keys.tail.forall(_ == keys.head))

    // A latch to make `withLock` be called almost at the same time
    val latch = new CountDownLatch(1)
    // Track how many threads get the lock at the same time
    val numThreadsHoldingLock = new AtomicInteger(0)
    // Track how many functions get called
    val numFuncCalled = new AtomicInteger(0)
    // A failure observed by any worker thread; rethrown on the test thread after the join
    @volatile var failure: Option[Throwable] = None

    // Record a failure when the observed holder count differs from the expected one.
    def check(actual: Int, expected: Int): Unit = {
      if (actual != expected) {
        failure = Some(new AssertionError(s"numThreadsHoldingLock is not $expected: $actual"))
      }
    }

    val threads = (0 until numThreads).map { i =>
      new Thread() {
        override def run(): Unit = try {
          if (!latch.await(foreverMs, TimeUnit.MILLISECONDS)) {
            throw new TimeoutException("start latch was not released")
          }
          keyLock.withLock(keys(i)) {
            check(numThreadsHoldingLock.get(), 0)
            check(numThreadsHoldingLock.incrementAndGet(), 1)
            check(numThreadsHoldingLock.decrementAndGet(), 0)
            numFuncCalled.incrementAndGet()
          }
        } catch {
          // Surface unexpected worker failures on the test thread instead of losing them.
          // An interrupt is not matched here: it is the exit signal sent from `finally` below.
          case scala.util.control.NonFatal(t) => failure = Some(t)
        }
      }
    }
    threads.foreach(_.start())
    latch.countDown()
    try {
      // Bound the wait so that a locking bug fails the test instead of hanging the suite
      failAfter(foreverMs.millis) {
        threads.foreach(_.join())
      }
    } finally {
      // No-op for finished threads; unblocks any worker still stuck after a timeout
      threads.foreach(_.interrupt())
    }
    failure.foreach(t => throw t)
    assert(numFuncCalled.get === numThreads)
  }

  test("A different key should not be locked") {
    val keyLock = new KeyLock[Object]
    val k1 = new Object
    val k2 = new Object

    // Start a thread to hold the lock for `k1` forever
    val latch = new CountDownLatch(1)
    val t = new Thread() {
      override def run(): Unit = try {
        keyLock.withLock(k1) {
          latch.countDown()
          Thread.sleep(foreverMs)
        }
      } catch {
        case _: InterruptedException => // Ignore it as it's the exit signal
      }
    }
    t.start()
    try {
      // Wait until the thread gets the lock for `k1`
      if (!latch.await(foreverMs, TimeUnit.MILLISECONDS)) {
        throw new TimeoutException("thread didn't get the lock")
      }

      var funcCalled = false
      // Verify we can acquire the lock for `k2` and call `func`
      failAfter(foreverMs.millis) {
        keyLock.withLock(k2) {
          funcCalled = true
        }
      }
      assert(funcCalled, "func is not called")
    } finally {
      t.interrupt()
      t.join()
    }
  }
}