diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 3f4d236571..ab08698035 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -220,12 +220,6 @@ private[yarn] class ExecutorRunnable( val env = new HashMap[String, String]() Client.populateClasspath(null, conf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH)) - sparkConf.getExecutorEnv.foreach { case (key, value) => - // This assumes each executor environment variable set here is a path - // This is kept for backward compatibility and consistency with hadoop - YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) - } - // lookup appropriate http scheme for container log urls val yarnHttpPolicy = conf.get( YarnConfiguration.YARN_HTTP_POLICY_KEY, @@ -233,6 +227,20 @@ private[yarn] class ExecutorRunnable( ) val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } + + sparkConf.getExecutorEnv.foreach { case (key, value) => + if (key == Environment.CLASSPATH.name()) { + // If the key of env variable is CLASSPATH, we assume it is a path and append it. + // This is kept for backward compatibility and consistency with hadoop + YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) + } else { + // For other env variables, simply overwrite the value. + env(key) = value + } + } + // Add log urls container.foreach { c => sys.env.get("SPARK_USER").foreach { user => @@ -245,8 +253,6 @@ private[yarn] class ExecutorRunnable( } } - System.getenv().asScala.filterKeys(_.startsWith("SPARK")) - .foreach { case (k, v) => env(k) = v } env } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 33d400a5b1..a129be7c06 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -225,6 +225,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite { finalState should be (SparkAppHandle.State.FAILED) } + test("executor env overwrite AM env in client mode") { + testExecutorEnv(true) + } + + test("executor env overwrite AM env in cluster mode") { + testExecutorEnv(false) + } + private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = { val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), @@ -305,6 +313,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite { checkResult(finalState, executorResult, "OVERRIDDEN") } + private def testExecutorEnv(clientMode: Boolean): Unit = { + val result = File.createTempFile("result", null, tempDir) + val finalState = runSpark(clientMode, mainClassName(ExecutorEnvTestApp.getClass), + appArgs = Seq(result.getAbsolutePath), + extraConf = Map( + "spark.yarn.appMasterEnv.TEST_ENV" -> "am_val", + "spark.executorEnv.TEST_ENV" -> "executor_val" + ) + ) + checkResult(finalState, result, "true") + } } private[spark] class SaveExecutorInfo extends SparkListener { @@ -526,3 +545,20 @@ private object SparkContextTimeoutApp { } } + +private object ExecutorEnvTestApp { + + def main(args: Array[String]): Unit = { + val status = args(0) + val sparkConf = new SparkConf() + val sc = new SparkContext(sparkConf) + val executorEnvs = sc.parallelize(Seq(1)).flatMap { _ => sys.env }.collect().toMap + val result = sparkConf.getExecutorEnv.forall { case (k, v) => + executorEnvs.get(k).contains(v) + } + + Files.write(result.toString, new File(status), StandardCharsets.UTF_8) + sc.stop() + } + +}