diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 6fade10b7a..60d04046c7 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -222,6 +222,9 @@ private[spark] class ExecutorAllocationManager(
       throw new SparkException("Dynamic allocation of executors requires the external " +
         "shuffle service. You may enable this through spark.shuffle.service.enabled.")
     }
+    if (tasksPerExecutorForFullParallelism == 0) {
+      throw new SparkException(s"${EXECUTOR_CORES.key} must not be < ${CPUS_PER_TASK.key}.")
+    }
 
     if (executorAllocationRatio > 1.0 || executorAllocationRatio <= 0.0) {
       throw new SparkException(
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 30d8aa4e50..7050396e84 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -577,6 +577,17 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
       }
     }
 
+    // Only enforced when both keys are present in this conf; otherwise the check is skipped.
+    if (contains(EXECUTOR_CORES) && contains(CPUS_PER_TASK)) {
+      val executorCores = get(EXECUTOR_CORES)
+      val taskCpus = get(CPUS_PER_TASK)
+
+      if (executorCores < taskCpus) {
+        throw new SparkException(
+          s"${EXECUTOR_CORES.key} must not be less than ${CPUS_PER_TASK.key}.")
+      }
+    }
+
     val encryptionEnabled = get(NETWORK_CRYPTO_ENABLED) || get(SASL_ENCRYPTION_ENABLED)
     require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED),
       s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.")
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 71604667bc..4abb18d4aa 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2665,25 +2665,8 @@ object SparkContext extends Logging {
     // When running locally, don't try to re-execute tasks on failure.
     val MAX_LOCAL_TASK_FAILURES = 1
 
-    // SPARK-26340: Ensure that executor's core num meets at least one task requirement.
-    def checkCpusPerTask(
-      executorCoreNum: Int = sc.conf.get(EXECUTOR_CORES),
-      clusterMode: Boolean = true): Unit = {
-      val cpusPerTask = sc.conf.get(CPUS_PER_TASK)
-      if (executorCoreNum < cpusPerTask) {
-        val message = if (clusterMode) {
-          s"${CPUS_PER_TASK.key} must be <= ${EXECUTOR_CORES.key} when run on $master."
-        } else {
-          s"Only $executorCoreNum cores available per executor when run on $master," +
-            s" and ${CPUS_PER_TASK.key} must be <= it."
-        }
-        throw new SparkException(message)
-      }
-    }
-
     master match {
       case "local" =>
-        checkCpusPerTask(1, clusterMode = false)
         val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
         val backend = new LocalSchedulerBackend(sc.getConf, scheduler, 1)
         scheduler.initialize(backend)
@@ -2696,7 +2679,6 @@ object SparkContext extends Logging {
         if (threadCount <= 0) {
           throw new SparkException(s"Asked to run locally with $threadCount threads")
         }
-        checkCpusPerTask(threadCount, clusterMode = false)
         val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
         val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
         scheduler.initialize(backend)
@@ -2707,14 +2689,12 @@ object SparkContext extends Logging {
         // local[*, M] means the number of cores on the computer with M failures
         // local[N, M] means exactly N threads with M failures
         val threadCount = if (threads == "*") localCpuCount else threads.toInt
-        checkCpusPerTask(threadCount, clusterMode = false)
         val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
         val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
         scheduler.initialize(backend)
         (backend, scheduler)
 
       case SPARK_REGEX(sparkUrl) =>
-        checkCpusPerTask()
         val scheduler = new TaskSchedulerImpl(sc)
         val masterUrls = sparkUrl.split(",").map("spark://" + _)
         val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls)
@@ -2722,7 +2702,6 @@ object SparkContext extends Logging {
         (backend, scheduler)
 
       case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
-        checkCpusPerTask()
         // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
         val memoryPerSlaveInt = memoryPerSlave.toInt
         if (sc.executorMemory > memoryPerSlaveInt) {
@@ -2743,7 +2722,6 @@ object SparkContext extends Logging {
         (backend, scheduler)
 
       case masterUrl =>
-        checkCpusPerTask()
         val cm = getClusterManager(masterUrl) match {
           case Some(clusterMgr) => clusterMgr
           case None => throw new SparkException("Could not parse Master URL: '" + master + "'")
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 8a57b90a76..5ca4f9c73f 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -140,6 +140,16 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
     assert(sc.appName === "My other app")
   }
 
+  test("creating SparkContext with cpus per tasks bigger than cores per executors") {
+    val conf = new SparkConf(false)
+      .set(EXECUTOR_CORES, 1)
+      .set(CPUS_PER_TASK, 2)
+    // Pin the message: a bare intercept would also pass on any unrelated SparkException.
+    val ex = intercept[SparkException] { sc = new SparkContext(conf) }
+    assert(ex.getMessage.contains(
+      s"${EXECUTOR_CORES.key} must not be less than ${CPUS_PER_TASK.key}"))
+  }
+
   test("nested property names") {
     // This wasn't supported by some external conf parsing libraries
     System.setProperty("spark.test.a", "a")
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 3490eaf550..7a16f7b715 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -710,27 +710,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
       assert(runningTaskIds.isEmpty)
     }
   }
-
-  test(s"Avoid setting ${CPUS_PER_TASK.key} unreasonably (SPARK-27192)") {
-    val FAIL_REASON = s"${CPUS_PER_TASK.key} must be <="
-    Seq(
-      ("local", 2, None),
-      ("local[2]", 3, None),
-      ("local[2, 1]", 3, None),
-      ("spark://test-spark-cluster", 2, Option(1)),
-      ("local-cluster[1, 1, 1000]", 2, Option(1)),
-      ("yarn", 2, Option(1))
-    ).foreach { case (master, cpusPerTask, executorCores) =>
-      val conf = new SparkConf()
-      conf.set(CPUS_PER_TASK, cpusPerTask)
-      executorCores.map(executorCores => conf.set(EXECUTOR_CORES, executorCores))
-      val ex = intercept[SparkException] {
-        sc = new SparkContext(master, "test", conf)
-      }
-      assert(ex.getMessage.contains(FAIL_REASON))
-      resetSparkContext()
-    }
-  }
 }
 
 object SparkContextSuite {