diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 81e4c8f031..1b4e7ba510 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -163,7 +163,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
   }
 
   private def canShuffleMergeBeEnabled(): Boolean = {
-    val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf)
+    val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf,
+      // invoked at driver
+      isDriver = true)
     if (isPushShuffleEnabled && rdd.isBarrier()) {
       logWarning("Push-based shuffle is currently not supported for barrier stages")
     }
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 24954e7674..ca1229a737 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -617,7 +617,7 @@ private[spark] class MapOutputTrackerMaster(
   private val mapOutputTrackerMasterMessages =
     new LinkedBlockingQueue[MapOutputTrackerMasterMessage]
 
-  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf)
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf, isDriver = true)
 
   // Thread pool used for handling map output status requests. This is a separate thread pool
   // to ensure we don't block the normal dispatcher threads.
@@ -1126,7 +1126,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
   val mergeStatuses: Map[Int, Array[MergeStatus]] =
     new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala
 
-  private val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf)
+  // This must be lazy to ensure that it is initialized when the first task is run and not at
+  // executor startup time. At startup time, user-added libraries may not have been
+  // downloaded to the executor, causing `isPushBasedShuffleEnabled` to fail when it tries to
+  // instantiate a serializer. See the followup to SPARK-36705 for more details.
+  private lazy val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf, isDriver = false)
 
   /**
    * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index ee50a8f836..0388c7b576 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -272,33 +272,7 @@ object SparkEnv extends Logging {
       conf.set(DRIVER_PORT, rpcEnv.address.port)
     }
 
-    // Create an instance of the class with the given name, possibly initializing it with our conf
-    def instantiateClass[T](className: String): T = {
-      val cls = Utils.classForName(className)
-      // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
-      // SparkConf, then one taking no arguments
-      try {
-        cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
-          .newInstance(conf, java.lang.Boolean.valueOf(isDriver))
-          .asInstanceOf[T]
-      } catch {
-        case _: NoSuchMethodException =>
-          try {
-            cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
-          } catch {
-            case _: NoSuchMethodException =>
-              cls.getConstructor().newInstance().asInstanceOf[T]
-          }
-      }
-    }
-
-    // Create an instance of the class named by the given SparkConf property
-    // if the property is not set, possibly initializing it with our conf
-    def instantiateClassFromConf[T](propertyName: ConfigEntry[String]): T = {
-      instantiateClass[T](conf.get(propertyName))
-    }
-
-    val serializer = instantiateClassFromConf[Serializer](SERIALIZER)
+    val serializer = Utils.instantiateSerializerFromConf[Serializer](SERIALIZER, conf, isDriver)
     logDebug(s"Using serializer: ${serializer.getClass}")
 
     val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)
@@ -337,7 +311,8 @@ object SparkEnv extends Logging {
     val shuffleMgrName = conf.get(config.SHUFFLE_MANAGER)
     val shuffleMgrClass =
       shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName)
-    val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
+    val shuffleManager = Utils.instantiateSerializerOrShuffleManager[ShuffleManager](
+      shuffleMgrClass, conf, isDriver)
 
     val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores)
 
@@ -370,7 +345,7 @@ object SparkEnv extends Logging {
         } else {
           None
         }, blockManagerInfo,
-        mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])),
+        mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], isDriver)),
       registerOrLookupEndpoint(
         BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
         new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)),
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a3df49ae44..442edc732c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -254,7 +254,7 @@ private[spark] class DAGScheduler(
   private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
   taskScheduler.setDAGScheduler(this)
 
-  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf)
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf, isDriver = true)
 
   private val blockManagerMasterDriverHeartbeatTimeout =
     sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
index bb260f89cd..50f9c8cef5 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.ExecutorService
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
 
-import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv}
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv}
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
@@ -463,7 +463,8 @@ private[spark] object ShuffleBlockPusher {
 
   private val BLOCK_PUSHER_POOL: ExecutorService = {
     val conf = SparkEnv.get.conf
-    if (Utils.isPushBasedShuffleEnabled(conf)) {
+    if (Utils.isPushBasedShuffleEnabled(conf,
+        isDriver = SparkContext.DRIVER_IDENTIFIER == SparkEnv.get.executorId)) {
       val numThreads = conf.get(SHUFFLE_NUM_PUSH_THREADS)
         .getOrElse(conf.getInt(SparkLauncher.EXECUTOR_CORES, 1))
       ThreadUtils.newDaemonFixedThreadPool(numThreads, "shuffle-block-push-thread")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index cbb4e9c9ea..9ebf26b612 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -185,6 +185,7 @@ private[spark] class BlockManager(
   // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)`
   private[spark] val externalShuffleServiceEnabled: Boolean = externalBlockStoreClient.isDefined
+  private val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
 
   private val remoteReadNioBufferConversion =
     conf.get(Network.NETWORK_REMOTE_READ_NIO_BUFFER_CONVERSION)
@@ -194,8 +195,8 @@ private[spark] class BlockManager(
   val diskBlockManager = {
     // Only perform cleanup if an external service is not serving our shuffle files.
     val deleteFilesOnStop =
-      !externalShuffleServiceEnabled || executorId == SparkContext.DRIVER_IDENTIFIER
-    new DiskBlockManager(conf, deleteFilesOnStop)
+      !externalShuffleServiceEnabled || isDriver
+    new DiskBlockManager(conf, deleteFilesOnStop = deleteFilesOnStop, isDriver = isDriver)
   }
 
   // Visible for testing
@@ -535,7 +536,7 @@ private[spark] class BlockManager(
     hostLocalDirManager = {
       if ((conf.get(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED) &&
           !conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) ||
-          Utils.isPushBasedShuffleEnabled(conf)) {
+          Utils.isPushBasedShuffleEnabled(conf, isDriver)) {
         Some(new HostLocalDirManager(
           futureExecutionContext,
           conf.get(config.STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE),
@@ -561,7 +562,7 @@ private[spark] class BlockManager(
   private def registerWithExternalShuffleServer(): Unit = {
     logInfo("Registering executor with local external shuffle service.")
     val shuffleManagerMeta =
-      if (Utils.isPushBasedShuffleEnabled(conf)) {
+      if (Utils.isPushBasedShuffleEnabled(conf, isDriver = isDriver, checkSerializer = false)) {
         s"${shuffleManager.getClass.getName}:" +
           s"${diskBlockManager.getMergeDirectoryAndAttemptIDJsonString()}}}"
       } else {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 6f043da76d..b96befce2c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -51,7 +51,8 @@ class BlockManagerMasterEndpoint(
     listenerBus: LiveListenerBus,
     externalBlockStoreClient: Option[ExternalBlockStoreClient],
     blockManagerInfo: mutable.Map[BlockManagerId, BlockManagerInfo],
-    mapOutputTracker: MapOutputTrackerMaster)
+    mapOutputTracker: MapOutputTrackerMaster,
+    isDriver: Boolean)
   extends IsolatedRpcEndpoint with Logging {
 
   // Mapping from executor id to the block manager's local disk directories.
@@ -100,7 +101,7 @@ class BlockManagerMasterEndpoint(
 
   val defaultRpcTimeout = RpcUtils.askRpcTimeout(conf)
 
-  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf)
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf, isDriver)
 
   logInfo("BlockManagerMasterEndpoint up")
   // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index ee11e0e8ff..bebe32b952 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -45,7 +45,10 @@ import org.apache.spark.util.{ShutdownHookManager, Utils}
  *
  * ShuffleDataIO also can change the behavior of deleteFilesOnStop.
  */
-private[spark] class DiskBlockManager(conf: SparkConf, var deleteFilesOnStop: Boolean)
+private[spark] class DiskBlockManager(
+    conf: SparkConf,
+    var deleteFilesOnStop: Boolean,
+    isDriver: Boolean)
   extends Logging {
 
   private[spark] val subDirsPerLocalDir = conf.get(config.DISKSTORE_SUB_DIRECTORIES)
@@ -208,7 +211,7 @@ private[spark] class DiskBlockManager(conf: SparkConf, var deleteFilesOnStop: Bo
    * permission to create directories under application local directories.
    */
   private def createLocalDirsForMergedShuffleBlocks(): Unit = {
-    if (Utils.isPushBasedShuffleEnabled(conf)) {
+    if (Utils.isPushBasedShuffleEnabled(conf, isDriver = isDriver, checkSerializer = false)) {
      // Will create the merge_manager directory only if it doesn't exist under the local dir.
      Utils.getConfiguredLocalDirs(conf).foreach { rootDir =>
        try {
diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
index 99138b670a..d83d9018ad 100644
--- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
+++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
@@ -142,7 +142,7 @@ private class PushBasedFetchHelper(
     val mergedBlocksMetaListener = new MergedBlocksMetaListener {
       override def onSuccess(shuffleId: Int, shuffleMergeId: Int, reduceId: Int,
           meta: MergedBlockMeta): Unit = {
-        logInfo(s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," +
+        logDebug(s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," +
           s" $reduceId) from ${req.address.host}:${req.address.port}")
         try {
           iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, shuffleMergeId,
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index f3fc90d061..0029bbd713 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2603,18 +2603,31 @@ private[spark] object Utils extends Logging {
    * - IO encryption disabled
    * - serializer(such as KryoSerializer) supports relocation of serialized objects
    */
-  def isPushBasedShuffleEnabled(conf: SparkConf): Boolean = {
+  def isPushBasedShuffleEnabled(conf: SparkConf,
+      isDriver: Boolean,
+      checkSerializer: Boolean = true): Boolean = {
     val pushBasedShuffleEnabled = conf.get(PUSH_BASED_SHUFFLE_ENABLED)
     if (pushBasedShuffleEnabled) {
-      val serializer = Utils.classForName(conf.get(SERIALIZER)).getConstructor(classOf[SparkConf])
-        .newInstance(conf).asInstanceOf[Serializer]
-      val canDoPushBasedShuffle = conf.get(IS_TESTING).getOrElse(false) ||
-        (conf.get(SHUFFLE_SERVICE_ENABLED) &&
-          conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn" &&
-          // TODO: [SPARK-36744] needs to support IO encryption for push-based shuffle
-          !conf.get(IO_ENCRYPTION_ENABLED) &&
-          serializer.supportsRelocationOfSerializedObjects)
-
+      val canDoPushBasedShuffle = {
+        val isTesting = conf.get(IS_TESTING).getOrElse(false)
+        val isShuffleServiceAndYarn = conf.get(SHUFFLE_SERVICE_ENABLED) &&
+          conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn"
+        lazy val serializerIsSupported = {
+          if (checkSerializer) {
+            Option(SparkEnv.get)
+              .map(_.serializer)
+              .filter(_ != null)
+              .getOrElse(instantiateSerializerFromConf[Serializer](SERIALIZER, conf, isDriver))
+              .supportsRelocationOfSerializedObjects
+          } else {
+            // if no need to check Serializer, always set serializerIsSupported as true
+            true
+          }
+        }
+        // TODO: [SPARK-36744] needs to support IO encryption for push-based shuffle
+        val ioEncryptionDisabled = !conf.get(IO_ENCRYPTION_ENABLED)
+        (isShuffleServiceAndYarn || isTesting) && ioEncryptionDisabled && serializerIsSupported
+      }
       if (!canDoPushBasedShuffle) {
         logWarning("Push-based shuffle can only be enabled when the application is submitted " +
           "to run in YARN mode, with external shuffle service enabled, IO encryption disabled, " +
@@ -2627,6 +2640,38 @@ private[spark] object Utils extends Logging {
     }
   }
 
+  // Create an instance of Serializer or ShuffleManager with the given name,
+  // possibly initializing it with our conf
+  def instantiateSerializerOrShuffleManager[T](className: String,
+      conf: SparkConf,
+      isDriver: Boolean): T = {
+    val cls = Utils.classForName(className)
+    // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
+    // SparkConf, then one taking no arguments
+    try {
+      cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
+        .newInstance(conf, java.lang.Boolean.valueOf(isDriver))
+        .asInstanceOf[T]
+    } catch {
+      case _: NoSuchMethodException =>
+        try {
+          cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+        } catch {
+          case _: NoSuchMethodException =>
+            cls.getConstructor().newInstance().asInstanceOf[T]
+        }
+    }
+  }
+
+  // Create an instance of Serializer named by the given SparkConf property
+  // if the property is not set, possibly initializing it with our conf
+  def instantiateSerializerFromConf[T](propertyName: ConfigEntry[String],
+      conf: SparkConf,
+      isDriver: Boolean): T = {
+    instantiateSerializerOrShuffleManager[T](
+      conf.get(propertyName), conf, isDriver)
+  }
+
   /**
    * Return whether dynamic allocation is enabled in the given conf.
    */
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index e81196f8ea..4051118572 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -337,6 +337,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
   test("SPARK-32921: master register and unregister merge result") {
     conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
     conf.set(IS_TESTING, true)
+    conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
     val rpcEnv = createRpcEnv("test")
     val tracker = newTrackerMaster()
     tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
@@ -596,6 +597,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     newConf.set(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST, 10240L) // 10 KiB << 1MiB framesize
     newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
     newConf.set(IS_TESTING, true)
+    newConf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
 
     // needs TorrentBroadcast so need a SparkContext
     withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index deddaea4df..4cb64ed0c2 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3431,6 +3431,10 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     conf.set("spark.master", "pushbasedshuffleclustermanager")
     // Needed to run push-based shuffle tests in ad-hoc manner through IDE
     conf.set(Tests.IS_TESTING, true)
+    // [SPARK-36705] Push-based shuffle does not work with Spark's default
+    // JavaSerializer and will be disabled with it, as it does not support
+    // object relocation
+    conf.set(config.SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
   }
 
   test("SPARK-32920: shuffle merge finalization") {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala
index 33f544a391..4e74036e11 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala
@@ -139,6 +139,7 @@ class HostLocalShuffleReadingSuite extends SparkFunSuite with Matchers with Loca
       .set(SHUFFLE_SERVICE_ENABLED, true)
       .set("spark.yarn.maxAttempts", "1")
       .set(PUSH_BASED_SHUFFLE_ENABLED, true)
+      .set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
     sc = new SparkContext("local-cluster[2, 1, 1024]", "test-host-local-shuffle-reading", conf)
     sc.env.blockManager.hostLocalDirManager.isDefined should equal(true)
   }
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index 495747b2c7..fc7b7a4406 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -102,7 +102,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite
     val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]()
     master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
       new BlockManagerMasterEndpoint(rpcEnv, true, conf,
-        new LiveListenerBus(conf), None, blockManagerInfo, mapOutputTracker)),
+        new LiveListenerBus(conf), None, blockManagerInfo, mapOutputTracker, isDriver = true)),
       rpcEnv.setupEndpoint("blockmanagerHeartbeat",
       new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true)
     allStores.clear()
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 173b839dd5..2cb281d468 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -98,6 +98,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       .set(IS_TESTING, true)
       .set(MEMORY_FRACTION, 1.0)
       .set(MEMORY_STORAGE_FRACTION, 0.999)
+      .set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
       .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m")
      .set(STORAGE_UNROLL_MEMORY_THRESHOLD, 512L)
      .set(Network.RPC_ASK_TIMEOUT, "5s")
@@ -185,7 +186,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     liveListenerBus = spy(new LiveListenerBus(conf))
     master = spy(new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
       new BlockManagerMasterEndpoint(rpcEnv, true, conf,
-        liveListenerBus, None, blockManagerInfo, mapOutputTracker)),
+        liveListenerBus, None, blockManagerInfo, mapOutputTracker, isDriver = true)),
       rpcEnv.setupEndpoint("blockmanagerHeartbeat",
       new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true))
   }
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 0443c40bce..b36eeb767e 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -60,7 +60,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
     super.beforeEach()
     val conf = testConf.clone
     conf.set("spark.local.dir", rootDirs)
-    diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true, isDriver = false)
   }
 
   override def afterEach(): Unit = {
@@ -105,7 +105,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
     testConf.set("spark.local.dir", rootDirs)
     testConf.set("spark.shuffle.push.enabled", "true")
     testConf.set(config.Tests.IS_TESTING, true)
-    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true)
+    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true, isDriver = false)
     assert(Utils.getConfiguredLocalDirs(testConf).map(
       rootDir => new File(rootDir, DiskBlockManager.MERGE_DIRECTORY))
       .filter(mergeDir => mergeDir.exists()).length === 2)
@@ -118,7 +118,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
   test("Test dir creation with permission 770") {
     val testDir = new File("target/testDir");
     FileUtils.deleteQuietly(testDir)
-    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true)
+    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true, isDriver = false)
     diskBlockManager.createDirWithPermission770(testDir)
     assert(testDir.exists && testDir.isDirectory)
     val permission = PosixFilePermissions.toString(
@@ -129,7 +129,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
 
   test("Encode merged directory name and attemptId in shuffleManager field") {
     testConf.set(config.APP_ATTEMPT_ID, "1");
-    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true)
+    diskBlockManager = new DiskBlockManager(testConf, deleteFilesOnStop = true, isDriver = false)
     val mergedShuffleMeta = diskBlockManager.getMergeDirectoryAndAttemptIDJsonString();
     val mapper: ObjectMapper = new ObjectMapper
     val typeRef: TypeReference[HashMap[String, String]] =
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
index 97b9c973e9..be1b9be2d8 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
@@ -46,7 +46,7 @@ class DiskStoreSuite extends SparkFunSuite {
     val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes))
     val blockId = BlockId("rdd_1_2")
 
-    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true, isDriver = false)
 
     val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager,
       securityManager)
@@ -77,7 +77,7 @@ class DiskStoreSuite extends SparkFunSuite {
 
   test("block size tracking") {
     val conf = new SparkConf()
-    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true, isDriver = false)
     val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf))
 
     val blockId = BlockId("rdd_1_2")
@@ -96,7 +96,7 @@ class DiskStoreSuite extends SparkFunSuite {
   test("blocks larger than 2gb") {
     val conf = new SparkConf()
       .set(config.MEMORY_MAP_LIMIT_FOR_TESTS.key, "10k")
-    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true, isDriver = false)
     val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf))
 
     val blockId = BlockId("rdd_1_2")
@@ -137,7 +137,7 @@ class DiskStoreSuite extends SparkFunSuite {
     val conf = new SparkConf()
     val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf)))
 
-    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true, isDriver = false)
     val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
 
     val blockId = BlockId("rdd_1_2")
diff --git a/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala
index f58d8ce3ba..88197b6c5d 100644
--- a/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala
@@ -68,7 +68,7 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext {
     val bmm = new BlockManagerMaster(new NoopRpcEndpointRef(conf), null, conf, false)
 
     val bm = mock(classOf[BlockManager])
-    val dbm = new DiskBlockManager(conf, false)
+    val dbm = new DiskBlockManager(conf, deleteFilesOnStop = false, isDriver = false)
     when(bm.diskBlockManager).thenReturn(dbm)
     when(bm.master).thenReturn(bmm)
     val resolver = new IndexShuffleBlockResolver(conf, bm)
@@ -134,7 +134,7 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext {
     val ids = Set((1, 1L, 1))
 
     val bm = mock(classOf[BlockManager])
-    val dbm = new DiskBlockManager(conf, false)
+    val dbm = new DiskBlockManager(conf, deleteFilesOnStop = false, isDriver = false)
     when(bm.diskBlockManager).thenReturn(dbm)
     val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf, bm)
     val indexFile = indexShuffleBlockResolver.getIndexFile(1, 1L)
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index f8607f1b90..05b24ec5a1 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -1503,23 +1503,26 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
   test("isPushBasedShuffleEnabled when PUSH_BASED_SHUFFLE_ENABLED " +
     "and SHUFFLE_SERVICE_ENABLED are both set to true in YARN mode with maxAttempts set to 1") {
     val conf = new SparkConf()
-    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === false)
     conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
     conf.set(IS_TESTING, false)
-    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    assert(Utils.isPushBasedShuffleEnabled(
+      conf, isDriver = false, checkSerializer = false) === false)
     conf.set(SHUFFLE_SERVICE_ENABLED, true)
     conf.set(SparkLauncher.SPARK_MASTER, "yarn")
     conf.set("spark.yarn.maxAppAttempts", "1")
     conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
-    assert(Utils.isPushBasedShuffleEnabled(conf) === true)
+    assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === true)
     conf.set("spark.yarn.maxAppAttempts", "2")
-    assert(Utils.isPushBasedShuffleEnabled(conf) === true)
+    assert(Utils.isPushBasedShuffleEnabled(
+      conf, isDriver = false, checkSerializer = false) === true)
     conf.set(IO_ENCRYPTION_ENABLED, true)
-    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === false)
     conf.set(IO_ENCRYPTION_ENABLED, false)
-    assert(Utils.isPushBasedShuffleEnabled(conf) === true)
+    assert(Utils.isPushBasedShuffleEnabled(
+      conf, isDriver = false, checkSerializer = false) === true)
     conf.set(SERIALIZER, "org.apache.spark.serializer.JavaSerializer")
-    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === false)
   }
 }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 425e39c598..3bcea1ab2c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -93,7 +93,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
     val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]()
     blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
       new BlockManagerMasterEndpoint(rpcEnv, true, conf,
-        new LiveListenerBus(conf), None, blockManagerInfo, mapOutputTracker)),
+        new LiveListenerBus(conf), None, blockManagerInfo, mapOutputTracker, isDriver = true)),
       rpcEnv.setupEndpoint("blockmanagerHeartbeat",
       new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true)
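Note (not part of the patch): a minimal sketch of how the reworked Utils.isPushBasedShuffleEnabled(conf, isDriver, checkSerializer) is intended to be called on each side. The package and class below are hypothetical and exist only for illustration; the class is placed under org.apache.spark because Utils is private[spark].

// Hypothetical example, not part of this change.
package org.apache.spark.examples

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.util.Utils

class PushShuffleFlagSketch(conf: SparkConf, executorId: String) {
  private val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER

  // Driver-side callers (DAGScheduler, MapOutputTrackerMaster) can evaluate eagerly:
  // the configured serializer class is already on the driver's classpath.
  val enabledOnDriver: Boolean =
    Utils.isPushBasedShuffleEnabled(conf, isDriver = true)

  // Executor-side callers defer the check behind a lazy val, mirroring
  // MapOutputTrackerWorker.fetchMergeResult above, so user-added jars are
  // downloaded before a serializer has to be instantiated.
  lazy val enabledOnExecutor: Boolean =
    Utils.isPushBasedShuffleEnabled(conf, isDriver = false)

  // Callers that only decide on directory layout or registration metadata
  // (DiskBlockManager, external shuffle service registration) skip serializer
  // instantiation entirely via checkSerializer = false.
  val enabledForDirSetup: Boolean =
    Utils.isPushBasedShuffleEnabled(conf, isDriver = isDriver, checkSerializer = false)
}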