diff --git a/core/src/main/java/org/apache/spark/api/resource/ResourceDiscoveryPlugin.java b/core/src/main/java/org/apache/spark/api/resource/ResourceDiscoveryPlugin.java
new file mode 100644
index 0000000000..ffd2f83552
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/resource/ResourceDiscoveryPlugin.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.resource;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.SparkConf;
+import org.apache.spark.resource.ResourceInformation;
+import org.apache.spark.resource.ResourceRequest;
+
+/**
+ * :: DeveloperApi ::
+ * A plugin that can be dynamically loaded into a Spark application to control how custom
+ * resources are discovered. Plugins can be chained to allow different plugins to handle
+ * different resource types.
+ *
+ * Plugins must implement the method discoverResource.
+ *
+ * @since 3.0.0
+ */
+@DeveloperApi
+public interface ResourceDiscoveryPlugin {
+ /**
+ * Discover the addresses of the requested resource.
+ *
+ * This method is called early in the initialization of the Spark Executor/Driver/Worker.
+ * It is responsible for discovering the addresses of the resource, which Spark will
+ * then use for scheduling and eventually provide to the user.
+ * Depending on the deployment mode and the configuration of custom resources, this could be
+ * called by the Spark Driver, the Spark Executors, in standalone mode the Workers, or all of
+ * them. The ResourceRequest has a ResourceID component that can be used to distinguish which
+ * component it is called from and what resource it is being called for.
+ * This will get called once for each resource type requested, and it is the responsibility of
+ * this function to return enough addresses of that resource based on the request. If
+ * the addresses do not meet the requested amount, Spark will fail.
+ * If this plugin doesn't handle a particular resource, it should return an empty Optional,
+ * and Spark will try the other plugins and finally fall back to the default discovery
+ * script plugin.
+ *
+ * @param request The ResourceRequest to be discovered.
+ * @param sparkConf The SparkConf for the application.
+ * @return An {@link Optional} containing a {@link ResourceInformation} object with
+ * the resource name and the addresses of the resource. If it returns
+ * {@link Optional#empty()}, other plugins will be called.
+ */
+ Optional<ResourceInformation> discoverResource(ResourceRequest request, SparkConf sparkConf);
+}
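
For reference, a minimal custom plugin implementing this interface could look like the sketch below. The class name and the "custom" resource name are illustrative only, not part of this change.

    import java.util.Optional

    import org.apache.spark.SparkConf
    import org.apache.spark.api.resource.ResourceDiscoveryPlugin
    import org.apache.spark.resource.{ResourceInformation, ResourceRequest}

    // Handles only a hypothetical "custom" resource; returning an empty
    // Optional defers to the next plugin in the chain (or the script plugin).
    class MyCustomDiscoveryPlugin extends ResourceDiscoveryPlugin {
      override def discoverResource(
          request: ResourceRequest,
          sparkConf: SparkConf): Optional[ResourceInformation] = {
        if (request.id.resourceName == "custom") {
          // Hard-coded addresses for illustration; a real plugin would probe
          // the hardware or ask an external service here.
          Optional.of(new ResourceInformation("custom", Array("0", "1")))
        } else {
          Optional.empty()
        }
      }
    }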
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6e0c7acf8b..91188d58f4 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2806,17 +2806,17 @@ object SparkContext extends Logging {
// Make sure the executor resources were specified through config.
val execAmount = executorResourcesAndAmounts.getOrElse(taskReq.resourceName,
throw new SparkException("The executor resource config: " +
- ResourceID(SPARK_EXECUTOR_PREFIX, taskReq.resourceName).amountConf +
+ new ResourceID(SPARK_EXECUTOR_PREFIX, taskReq.resourceName).amountConf +
" needs to be specified since a task requirement config: " +
- ResourceID(SPARK_TASK_PREFIX, taskReq.resourceName).amountConf +
+ new ResourceID(SPARK_TASK_PREFIX, taskReq.resourceName).amountConf +
" was specified")
)
// Make sure the executor resources are large enough to launch at least one task.
if (execAmount < taskReq.amount) {
throw new SparkException("The executor resource config: " +
- ResourceID(SPARK_EXECUTOR_PREFIX, taskReq.resourceName).amountConf +
+ new ResourceID(SPARK_EXECUTOR_PREFIX, taskReq.resourceName).amountConf +
s" = $execAmount has to be >= the requested amount in task resource config: " +
- ResourceID(SPARK_TASK_PREFIX, taskReq.resourceName).amountConf +
+ new ResourceID(SPARK_TASK_PREFIX, taskReq.resourceName).amountConf +
s" = ${taskReq.amount}")
}
// Compare and update the max slots each executor can provide.
diff --git a/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala b/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala
index d6f9618af4..65bf4351eb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala
@@ -208,7 +208,7 @@ private[spark] object StandaloneResourceUtils extends Logging {
}
val newAllocation = {
val allocations = newAssignments.map { case (rName, addresses) =>
- ResourceAllocation(ResourceID(componentName, rName), addresses)
+ ResourceAllocation(new ResourceID(componentName, rName), addresses)
}.toSeq
StandaloneResourceAllocation(pid, allocations)
}
@@ -348,7 +348,7 @@ private[spark] object StandaloneResourceUtils extends Logging {
val compShortName = componentName.substring(componentName.lastIndexOf(".") + 1)
val tmpFile = Utils.tempFileWith(dir)
val allocations = resources.map { case (rName, rInfo) =>
- ResourceAllocation(ResourceID(componentName, rName), rInfo.addresses)
+ ResourceAllocation(new ResourceID(componentName, rName), rInfo.addresses)
}.toSeq
try {
writeResourceAllocationJson(componentName, allocations, tmpFile)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index ce211ce8dd..25c5b9812f 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -17,6 +17,7 @@
package org.apache.spark.executor
+import java.io.File
import java.net.URL
import java.nio.ByteBuffer
import java.util.Locale
@@ -42,7 +43,7 @@ import org.apache.spark.rpc._
import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, ThreadUtils, Utils}
private[spark] class CoarseGrainedExecutorBackend(
override val rpcEnv: RpcEnv,
@@ -99,15 +100,36 @@ private[spark] class CoarseGrainedExecutorBackend(
}(ThreadUtils.sameThread)
}
+ /**
+ * Create a ClassLoader for use in resource discovery. The user could provide a class
+ * as a substitute for the default plugin, so we have to be able to load it from a
+ * user-specified jar.
+ */
+ private def createClassLoader(): MutableURLClassLoader = {
+ val currentLoader = Utils.getContextOrSparkClassLoader
+ val urls = userClassPath.toArray
+ if (env.conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)) {
+ new ChildFirstURLClassLoader(urls, currentLoader)
+ } else {
+ new MutableURLClassLoader(urls, currentLoader)
+ }
+ }
+
// visible for testing
def parseOrFindResources(resourcesFileOpt: Option[String]): Map[String, ResourceInformation] = {
+ // use a classloader that includes the user classpath in case the user specified a
+ // class for resource discovery
+ val urlClassLoader = createClassLoader()
logDebug(s"Resource profile id is: ${resourceProfile.id}")
- val resources = getOrDiscoverAllResourcesForResourceProfile(
- resourcesFileOpt,
- SPARK_EXECUTOR_PREFIX,
- resourceProfile)
- logResourceInfo(SPARK_EXECUTOR_PREFIX, resources)
- resources
+ Utils.withContextClassLoader(urlClassLoader) {
+ val resources = getOrDiscoverAllResourcesForResourceProfile(
+ resourcesFileOpt,
+ SPARK_EXECUTOR_PREFIX,
+ resourceProfile,
+ env.conf)
+ logResourceInfo(SPARK_EXECUTOR_PREFIX, resources)
+ resources
+ }
}
def extractLogUrls: Map[String, String] = {
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index e68368f37a..f91f31be2f 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -54,6 +54,18 @@ package object config {
.stringConf
.createOptional
+ private[spark] val RESOURCES_DISCOVERY_PLUGIN =
+ ConfigBuilder("spark.resources.discovery.plugin")
+ .doc("Comma-separated list of class names implementing" +
+ "org.apache.spark.api.resource.ResourceDiscoveryPlugin to load into the application." +
+ "This is for advanced users to replace the resource discovery class with a " +
+ "custom implementation. Spark will try each class specified until one of them " +
+ "returns the resource information for that resource. It tries the discovery " +
+ "script last if none of the plugins return information for that resource.")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
private[spark] val DRIVER_RESOURCES_FILE =
ConfigBuilder("spark.driver.resourcesFile")
.internal()
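
As a usage sketch of the new config (the plugin class names here are hypothetical, and the script path is illustrative): plugins are tried in the order listed, and the discovery script is still consulted last for any resource none of them answers. The resource conf keys are the standard ones built from ResourceID.confPrefix, i.e. <component>.resource.<name>.<field>.

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      // Tried in order; a plugin that returns Optional.empty() for a resource
      // passes it on to the next one.
      .set("spark.resources.discovery.plugin",
        "com.example.MyCustomDiscoveryPlugin,com.example.MyFallbackPlugin")
      // Keys derived from ResourceID.confPrefix.
      .set("spark.executor.resource.gpu.amount", "1")
      .set("spark.executor.resource.gpu.discoveryScript", "/opt/getGpus.sh")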
diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceDiscoveryScriptPlugin.scala b/core/src/main/scala/org/apache/spark/resource/ResourceDiscoveryScriptPlugin.scala
new file mode 100644
index 0000000000..2ac6d3c500
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/resource/ResourceDiscoveryScriptPlugin.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.resource
+
+import java.io.File
+import java.util.Optional
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.api.resource.ResourceDiscoveryPlugin
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils.executeAndGetOutput
+
+/**
+ * The default plugin that is loaded into a Spark application to control how custom
+ * resources are discovered. This executes the discovery script specified by the user
+ * and gets the json output back and contructs ResourceInformation objects from that.
+ * If the user specifies custom plugins, this is the last one to be executed and
+ * throws if the resource isn't discovered.
+ */
+class ResourceDiscoveryScriptPlugin extends ResourceDiscoveryPlugin with Logging {
+ override def discoverResource(
+ request: ResourceRequest,
+ sparkConf: SparkConf): Optional[ResourceInformation] = {
+ val script = request.discoveryScript
+ val resourceName = request.id.resourceName
+ val result = if (script.isPresent) {
+ val scriptFile = new File(script.get)
+ logInfo(s"Discovering resources for $resourceName with script: $scriptFile")
+ // check that script exists and try to execute
+ if (scriptFile.exists()) {
+ val output = executeAndGetOutput(Seq(script.get), new File("."))
+ ResourceInformation.parseJson(output)
+ } else {
+ throw new SparkException(s"Resource script: $scriptFile to discover $resourceName " +
+ "doesn't exist!")
+ }
+ } else {
+ throw new SparkException(s"User is expecting to use resource: $resourceName, but " +
+ "didn't specify a discovery script!")
+ }
+ if (!result.name.equals(resourceName)) {
+ throw new SparkException(s"Error running the resource discovery script ${script.get}: " +
+ s"script returned resource name ${result.name} and we were expecting $resourceName.")
+ }
+ Optional.of(result)
+ }
+}
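
The script's stdout must be JSON that ResourceInformation.parseJson accepts, with a name matching the requested resource (otherwise the check above throws). A sketch of the expected shape, with the example output borrowed from the tests in this change; parseJson is private[spark], so the snippet is placed in the same package.

    package org.apache.spark.resource

    object DiscoveryScriptOutputExample {
      def main(args: Array[String]): Unit = {
        // What a GPU discovery script would print on stdout.
        val output = """{"name": "gpu", "addresses": ["5", "6"]}"""
        val info = ResourceInformation.parseJson(output)
        assert(info.name == "gpu")
        assert(info.addresses.sameElements(Array("5", "6")))
      }
    }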
diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala
index eb713a27be..14019d27fc 100644
--- a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala
+++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala
@@ -149,8 +149,8 @@ object ResourceProfile extends Logging {
val execReq = ResourceUtils.parseAllResourceRequests(conf, SPARK_EXECUTOR_PREFIX)
execReq.foreach { req =>
val name = req.id.resourceName
- ereqs.resource(name, req.amount, req.discoveryScript.getOrElse(""),
- req.vendor.getOrElse(""))
+ ereqs.resource(name, req.amount, req.discoveryScript.orElse(""),
+ req.vendor.orElse(""))
}
ereqs.requests
}
diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala b/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala
index 190b0cdc88..7dd7fc1b99 100644
--- a/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala
+++ b/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala
@@ -17,8 +17,8 @@
package org.apache.spark.resource
-import java.io.File
import java.nio.file.{Files, Paths}
+import java.util.Optional
import scala.util.control.NonFatal
@@ -26,39 +26,75 @@ import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.resource.ResourceDiscoveryPlugin
import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.SPARK_TASK_PREFIX
-import org.apache.spark.util.Utils.executeAndGetOutput
+import org.apache.spark.internal.config.{RESOURCES_DISCOVERY_PLUGIN, SPARK_TASK_PREFIX}
+import org.apache.spark.util.Utils
/**
* Resource identifier.
* @param componentName spark.driver / spark.executor / spark.task
* @param resourceName gpu, fpga, etc
+ *
+ * @since 3.0.0
*/
-private[spark] case class ResourceID(componentName: String, resourceName: String) {
- def confPrefix: String = s"$componentName.${ResourceUtils.RESOURCE_PREFIX}.$resourceName."
- def amountConf: String = s"$confPrefix${ResourceUtils.AMOUNT}"
- def discoveryScriptConf: String = s"$confPrefix${ResourceUtils.DISCOVERY_SCRIPT}"
- def vendorConf: String = s"$confPrefix${ResourceUtils.VENDOR}"
+@DeveloperApi
+class ResourceID(val componentName: String, val resourceName: String) {
+ private[spark] def confPrefix: String = {
+ s"$componentName.${ResourceUtils.RESOURCE_PREFIX}.$resourceName."
+ }
+ private[spark] def amountConf: String = s"$confPrefix${ResourceUtils.AMOUNT}"
+ private[spark] def discoveryScriptConf: String = s"$confPrefix${ResourceUtils.DISCOVERY_SCRIPT}"
+ private[spark] def vendorConf: String = s"$confPrefix${ResourceUtils.VENDOR}"
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case that: ResourceID =>
+ that.getClass == this.getClass &&
+ that.componentName == componentName && that.resourceName == resourceName
+ case _ =>
+ false
+ }
+ }
+
+ override def hashCode(): Int = Seq(componentName, resourceName).hashCode()
}
/**
- * Case class that represents a resource request at the executor level.
+ * Class that represents a resource request.
*
* The class used when discovering resources (using the discovery script),
- * or via the context as it is parsing configuration, for SPARK_EXECUTOR_PREFIX.
+ * or via the context as it is parsing configuration for the ResourceID.
*
* @param id object identifying the resource
* @param amount integer amount for the resource. Note that for a request (executor level),
* fractional resources does not make sense, so amount is an integer.
* @param discoveryScript optional discovery script file name
* @param vendor optional vendor name
+ *
+ * @since 3.0.0
*/
-private[spark] case class ResourceRequest(
- id: ResourceID,
- amount: Int,
- discoveryScript: Option[String],
- vendor: Option[String])
+@DeveloperApi
+class ResourceRequest(
+ val id: ResourceID,
+ val amount: Long,
+ val discoveryScript: Optional[String],
+ val vendor: Optional[String]) {
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case that: ResourceRequest =>
+ that.getClass == this.getClass &&
+ that.id == id && that.amount == amount && that.discoveryScript == discoveryScript &&
+ that.vendor == vendor
+ case _ =>
+ false
+ }
+ }
+
+ override def hashCode(): Int = Seq(id, amount, discoveryScript, vendor).hashCode()
+}
/**
* Case class that represents resource requirements for a component in a
@@ -105,15 +141,15 @@ private[spark] object ResourceUtils extends Logging {
val amount = settings.getOrElse(AMOUNT,
throw new SparkException(s"You must specify an amount for ${resourceId.resourceName}")
).toInt
- val discoveryScript = settings.get(DISCOVERY_SCRIPT)
- val vendor = settings.get(VENDOR)
- ResourceRequest(resourceId, amount, discoveryScript, vendor)
+ val discoveryScript = Optional.ofNullable(settings.get(DISCOVERY_SCRIPT).orNull)
+ val vendor = Optional.ofNullable(settings.get(VENDOR).orNull)
+ new ResourceRequest(resourceId, amount, discoveryScript, vendor)
}
def listResourceIds(sparkConf: SparkConf, componentName: String): Seq[ResourceID] = {
sparkConf.getAllWithPrefix(s"$componentName.$RESOURCE_PREFIX.").map { case (key, _) =>
key.substring(0, key.indexOf('.'))
- }.toSet.toSeq.map(name => ResourceID(componentName, name))
+ }.toSet.toSeq.map(name => new ResourceID(componentName, name))
}
def parseAllResourceRequests(
@@ -218,7 +254,7 @@ private[spark] object ResourceUtils extends Logging {
val otherResources = otherResourceIds.flatMap { id =>
val request = parseResourceRequest(sparkConf, id)
if (request.amount > 0) {
- Some(ResourceAllocation(id, discoverResource(request).addresses))
+ Some(ResourceAllocation(id, discoverResource(sparkConf, request).addresses))
} else {
None
}
@@ -274,6 +310,15 @@ private[spark] object ResourceUtils extends Logging {
resourceInfoMap
}
+ // create an empty Optional if the string is empty
+ private def emptyStringToOptional(optStr: String): Optional[String] = {
+ if (optStr.isEmpty) {
+ Optional.empty[String]
+ } else {
+ Optional.of(optStr)
+ }
+ }
+
/**
* This function is similar to getOrDiscoverallResources, except for it uses the ResourceProfile
* information instead of the application level configs.
@@ -290,14 +335,19 @@ private[spark] object ResourceUtils extends Logging {
def getOrDiscoverAllResourcesForResourceProfile(
resourcesFileOpt: Option[String],
componentName: String,
- resourceProfile: ResourceProfile): Map[String, ResourceInformation] = {
+ resourceProfile: ResourceProfile,
+ sparkConf: SparkConf): Map[String, ResourceInformation] = {
val fileAllocated = parseAllocated(resourcesFileOpt, componentName)
val fileAllocResMap = fileAllocated.map(a => (a.id.resourceName, a.toResourceInformation)).toMap
// only want to look at the ResourceProfile for resources not in the resources file
val execReq = ResourceProfile.getCustomExecutorResources(resourceProfile)
val filteredExecreq = execReq.filterNot { case (rname, _) => fileAllocResMap.contains(rname) }
val rpAllocations = filteredExecreq.map { case (rName, execRequest) =>
- val addrs = discoverResource(rName, Option(execRequest.discoveryScript)).addresses
+ val resourceId = new ResourceID(componentName, rName)
+ val scriptOpt = emptyStringToOptional(execRequest.discoveryScript)
+ val vendorOpt = emptyStringToOptional(execRequest.vendor)
+ val resourceReq = new ResourceRequest(resourceId, execRequest.amount, scriptOpt, vendorOpt)
+ val addrs = discoverResource(sparkConf, resourceReq).addresses
(rName, new ResourceInformation(rName, addrs))
}
val allAllocations = fileAllocResMap ++ rpAllocations
@@ -312,36 +362,24 @@ private[spark] object ResourceUtils extends Logging {
logInfo("==============================================================")
}
- // visible for test
private[spark] def discoverResource(
- resourceName: String,
- script: Option[String]): ResourceInformation = {
- val result = if (script.nonEmpty) {
- val scriptFile = new File(script.get)
- // check that script exists and try to execute
- if (scriptFile.exists()) {
- val output = executeAndGetOutput(Seq(script.get), new File("."))
- ResourceInformation.parseJson(output)
- } else {
- throw new SparkException(s"Resource script: $scriptFile to discover $resourceName " +
- "doesn't exist!")
+ sparkConf: SparkConf,
+ resourceRequest: ResourceRequest): ResourceInformation = {
+ // always put the discovery script plugin as last plugin
+ val discoveryScriptPlugin = "org.apache.spark.resource.ResourceDiscoveryScriptPlugin"
+ val pluginClasses = sparkConf.get(RESOURCES_DISCOVERY_PLUGIN) :+ discoveryScriptPlugin
+ val resourcePlugins = Utils.loadExtensions(classOf[ResourceDiscoveryPlugin], pluginClasses,
+ sparkConf)
+ // apply each plugin until one of them returns the information for this resource
+ resourcePlugins.foreach { plugin =>
+ val riOption = plugin.discoverResource(resourceRequest, sparkConf)
+ if (riOption.isPresent()) {
+ return riOption.get()
}
- } else {
- throw new SparkException(s"User is expecting to use resource: $resourceName, but " +
- "didn't specify a discovery script!")
}
- if (!result.name.equals(resourceName)) {
- throw new SparkException(s"Error running the resource discovery script ${script.get}: " +
- s"script returned resource name ${result.name} and we were expecting $resourceName.")
- }
- result
- }
-
- // visible for test
- private[spark] def discoverResource(resourceRequest: ResourceRequest): ResourceInformation = {
- val resourceName = resourceRequest.id.resourceName
- val script = resourceRequest.discoveryScript
- discoverResource(resourceName, script)
+ throw new SparkException(s"None of the discovery plugins returned ResourceInformation for " +
+ s"${resourceRequest.id.resourceName}")
}
// known types of resources
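
An equivalent way to read the plugin chain above: the first plugin returning a present Optional wins, and discovery fails only after every plugin (including the script plugin appended last) has declined. A sketch of the same resolution logic, assuming the configured plugin order is preserved (Utils.loadExtensions keeps it):

    import org.apache.spark.{SparkConf, SparkException}
    import org.apache.spark.api.resource.ResourceDiscoveryPlugin
    import org.apache.spark.resource.{ResourceInformation, ResourceRequest}

    // First present Optional wins; otherwise fail, mirroring discoverResource.
    def resolve(
        plugins: Seq[ResourceDiscoveryPlugin],
        req: ResourceRequest,
        conf: SparkConf): ResourceInformation = {
      plugins.iterator
        .map(_.discoverResource(req, conf))
        .collectFirst { case ri if ri.isPresent => ri.get }
        .getOrElse(throw new SparkException(
          "None of the discovery plugins returned ResourceInformation for " +
            req.id.resourceName))
    }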
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index c210eb0d60..3bc2061c4f 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -449,7 +449,7 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
conf.remove(TASK_FPGA_ID.amountConf)
// Ignore invalid prefix
- conf.set(ResourceID("spark.invalid.prefix", FPGA).amountConf, "1")
+ conf.set(new ResourceID("spark.invalid.prefix", FPGA).amountConf, "1")
taskResourceRequirement =
parseResourceRequirements(conf, SPARK_TASK_PREFIX)
.map(req => (req.resourceName, req.amount)).toMap
diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
index a996fc4a0b..3134a738b3 100644
--- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
@@ -164,7 +164,8 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
val parsedResources = backend.parseOrFindResources(Some(f1))
}.getMessage()
- assert(error.contains("Resource script: to discover gpu doesn't exist!"))
+ assert(error.contains("User is expecting to use resource: gpu, but didn't " +
+ "specify a discovery script!"))
}
}
diff --git a/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala b/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala
new file mode 100644
index 0000000000..7a05daa2ad
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.resource
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+import java.util.Optional
+import java.util.UUID
+
+import scala.concurrent.duration._
+
+import com.google.common.io.Files
+import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
+
+import org.apache.spark._
+import org.apache.spark.TestUtils.createTempScriptWithExpectedOutput
+import org.apache.spark.api.resource.ResourceDiscoveryPlugin
+import org.apache.spark.internal.config._
+import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.resource.ResourceUtils.{FPGA, GPU}
+import org.apache.spark.resource.TestResourceIDs._
+import org.apache.spark.util.Utils
+
+class ResourceDiscoveryPluginSuite extends SparkFunSuite with LocalSparkContext {
+
+ test("plugin initialization in non-local mode fpga and gpu") {
+ assume(!(Utils.isWindows))
+ withTempDir { dir =>
+ val conf = new SparkConf()
+ .setAppName(getClass().getName())
+ .set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]")
+ .set(RESOURCES_DISCOVERY_PLUGIN, Seq(classOf[TestResourceDiscoveryPluginGPU].getName(),
+ classOf[TestResourceDiscoveryPluginFPGA].getName()))
+ .set(TestResourceDiscoveryPlugin.TEST_PATH_CONF, dir.getAbsolutePath())
+ .set(WORKER_GPU_ID.amountConf, "2")
+ .set(TASK_GPU_ID.amountConf, "1")
+ .set(EXECUTOR_GPU_ID.amountConf, "1")
+ .set(SPARK_RESOURCES_DIR, dir.getName())
+ .set(WORKER_FPGA_ID.amountConf, "2")
+ .set(TASK_FPGA_ID.amountConf, "1")
+ .set(EXECUTOR_FPGA_ID.amountConf, "1")
+
+ sc = new SparkContext(conf)
+ TestUtils.waitUntilExecutorsUp(sc, 2, 10000)
+
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ val children = dir.listFiles()
+ assert(children != null)
+ assert(children.length >= 4)
+ val gpuFiles = children.filter(f => f.getName().contains(GPU))
+ val fpgaFiles = children.filter(f => f.getName().contains(FPGA))
+ assert(gpuFiles.length == 2)
+ assert(fpgaFiles.length == 2)
+ }
+ }
+ }
+
+ test("single plugin gpu") {
+ assume(!(Utils.isWindows))
+ withTempDir { dir =>
+ val conf = new SparkConf()
+ .setAppName(getClass().getName())
+ .set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]")
+ .set(RESOURCES_DISCOVERY_PLUGIN, Seq(classOf[TestResourceDiscoveryPluginGPU].getName()))
+ .set(TestResourceDiscoveryPlugin.TEST_PATH_CONF, dir.getAbsolutePath())
+ .set(WORKER_GPU_ID.amountConf, "2")
+ .set(TASK_GPU_ID.amountConf, "1")
+ .set(EXECUTOR_GPU_ID.amountConf, "1")
+ .set(SPARK_RESOURCES_DIR, dir.getName())
+
+ sc = new SparkContext(conf)
+ TestUtils.waitUntilExecutorsUp(sc, 2, 10000)
+
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ val children = dir.listFiles()
+ assert(children != null)
+ assert(children.length >= 2)
+ val gpuFiles = children.filter(f => f.getName().contains(GPU))
+ assert(gpuFiles.length == 2)
+ }
+ }
+ }
+
+ test("multiple plugins with one empty") {
+ assume(!(Utils.isWindows))
+ withTempDir { dir =>
+ val conf = new SparkConf()
+ .setAppName(getClass().getName())
+ .set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]")
+ .set(RESOURCES_DISCOVERY_PLUGIN, Seq(classOf[TestResourceDiscoveryPluginEmpty].getName(),
+ classOf[TestResourceDiscoveryPluginGPU].getName()))
+ .set(TestResourceDiscoveryPlugin.TEST_PATH_CONF, dir.getAbsolutePath())
+ .set(WORKER_GPU_ID.amountConf, "2")
+ .set(TASK_GPU_ID.amountConf, "1")
+ .set(EXECUTOR_GPU_ID.amountConf, "1")
+ .set(SPARK_RESOURCES_DIR, dir.getName())
+
+ sc = new SparkContext(conf)
+ TestUtils.waitUntilExecutorsUp(sc, 2, 10000)
+
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ val children = dir.listFiles()
+ assert(children != null)
+ assert(children.length >= 2)
+ val gpuFiles = children.filter(f => f.getName().contains(GPU))
+ assert(gpuFiles.length == 2)
+ }
+ }
+ }
+
+ test("empty plugin fallback to discovery script") {
+ assume(!(Utils.isWindows))
+ withTempDir { dir =>
+ val scriptPath = createTempScriptWithExpectedOutput(dir, "gpuDiscoveryScript",
+ """{"name": "gpu","addresses":["5", "6"]}""")
+ val conf = new SparkConf()
+ .setAppName(getClass().getName())
+ .set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]")
+ .set(RESOURCES_DISCOVERY_PLUGIN, Seq(classOf[TestResourceDiscoveryPluginEmpty].getName()))
+ .set(DRIVER_GPU_ID.discoveryScriptConf, scriptPath)
+ .set(DRIVER_GPU_ID.amountConf, "2")
+ .set(SPARK_RESOURCES_DIR, dir.getName())
+
+ sc = new SparkContext(conf)
+ TestUtils.waitUntilExecutorsUp(sc, 2, 10000)
+
+ assert(sc.resources.size === 1)
+ assert(sc.resources.get(GPU).get.addresses === Array("5", "6"))
+ assert(sc.resources.get(GPU).get.name === "gpu")
+ }
+ }
+}
+
+object TestResourceDiscoveryPlugin {
+ val TEST_PATH_CONF = "spark.nonLocalDiscoveryPlugin.path"
+
+ def writeFile(conf: SparkConf, id: String): Unit = {
+ val path = conf.get(TEST_PATH_CONF)
+ val fileName = s"$id - ${UUID.randomUUID.toString}"
+ Files.write(id, new File(path, fileName), StandardCharsets.UTF_8)
+ }
+}
+
+private class TestResourceDiscoveryPluginGPU extends ResourceDiscoveryPlugin {
+
+ override def discoverResource(
+ request: ResourceRequest,
+ conf: SparkConf): Optional[ResourceInformation] = {
+ if (request.id.resourceName.equals(GPU)) {
+ TestResourceDiscoveryPlugin.writeFile(conf, request.id.resourceName)
+ Optional.of(new ResourceInformation(GPU, Array("0", "1", "2", "3")))
+ } else {
+ Optional.empty()
+ }
+ }
+}
+
+private class TestResourceDiscoveryPluginEmpty extends ResourceDiscoveryPlugin {
+
+ override def discoverResource(
+ request: ResourceRequest,
+ conf: SparkConf): Optional[ResourceInformation] = {
+ Optional.empty()
+ }
+}
+
+private class TestResourceDiscoveryPluginFPGA extends ResourceDiscoveryPlugin {
+
+ override def discoverResource(
+ request: ResourceRequest,
+ conf: SparkConf): Optional[ResourceInformation] = {
+ if (request.id.resourceName.equals(FPGA)) {
+ TestResourceDiscoveryPlugin.writeFile(conf, request.id.resourceName)
+ Optional.of(new ResourceInformation(FPGA, Array("0", "1", "2", "3")))
+ } else {
+ Optional.empty()
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/resource/ResourceUtilsSuite.scala b/core/src/test/scala/org/apache/spark/resource/ResourceUtilsSuite.scala
index b809469fd7..dffe9a02e9 100644
--- a/core/src/test/scala/org/apache/spark/resource/ResourceUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/resource/ResourceUtilsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.resource
import java.io.File
import java.nio.file.{Files => JavaFiles}
+import java.util.Optional
import org.json4s.{DefaultFormats, Extraction}
@@ -35,7 +36,7 @@ class ResourceUtilsSuite extends SparkFunSuite
test("ResourceID") {
val componentName = "spark.test"
val resourceName = "p100"
- val id = ResourceID(componentName, resourceName)
+ val id = new ResourceID(componentName, resourceName)
val confPrefix = s"$componentName.resource.$resourceName."
assert(id.confPrefix === confPrefix)
assert(id.amountConf === s"${confPrefix}amount")
@@ -91,7 +92,7 @@ class ResourceUtilsSuite extends SparkFunSuite
// test one with amount 0 to make sure ignored
val fooDiscovery = createTempScriptWithExpectedOutput(dir, "fooDiscoverScript",
"""{"name": "foo", "addresses": ["f1", "f2", "f3"]}""")
- val fooId = ResourceID(SPARK_EXECUTOR_PREFIX, "foo")
+ val fooId = new ResourceID(SPARK_EXECUTOR_PREFIX, "foo")
conf.set(fooId.amountConf, "0")
conf.set(fooId.discoveryScriptConf, fooDiscovery)
@@ -153,7 +154,8 @@ class ResourceUtilsSuite extends SparkFunSuite
val resourcesFromFileOnly = getOrDiscoverAllResourcesForResourceProfile(
Some(resourcesFile),
SPARK_EXECUTOR_PREFIX,
- ResourceProfile.getOrCreateDefaultProfile(conf))
+ ResourceProfile.getOrCreateDefaultProfile(conf),
+ conf)
val expectedFpgaInfo = new ResourceInformation(FPGA, fpgaAddrs.toArray)
assert(resourcesFromFileOnly(FPGA) === expectedFpgaInfo)
@@ -165,7 +167,7 @@ class ResourceUtilsSuite extends SparkFunSuite
val treqs = new TaskResourceRequests().resource(GPU, 1)
val rp = rpBuilder.require(ereqs).require(treqs).build
val resourcesFromBoth = getOrDiscoverAllResourcesForResourceProfile(
- Some(resourcesFile), SPARK_EXECUTOR_PREFIX, rp)
+ Some(resourcesFile), SPARK_EXECUTOR_PREFIX, rp, conf)
val expectedGpuInfo = new ResourceInformation(GPU, Array("0", "1"))
assert(resourcesFromBoth(FPGA) === expectedFpgaInfo)
assert(resourcesFromBoth(GPU) === expectedGpuInfo)
@@ -193,8 +195,8 @@ class ResourceUtilsSuite extends SparkFunSuite
var request = parseResourceRequest(conf, DRIVER_GPU_ID)
assert(request.id.resourceName === GPU, "should only have GPU for resource")
assert(request.amount === 2, "GPU count should be 2")
- assert(request.discoveryScript === None, "discovery script should be empty")
- assert(request.vendor === None, "vendor should be empty")
+ assert(request.discoveryScript === Optional.empty(), "discovery script should be empty")
+ assert(request.vendor === Optional.empty(), "vendor should be empty")
val vendor = "nvidia.com"
val discoveryScript = "discoveryScriptGPU"
@@ -240,14 +242,14 @@ class ResourceUtilsSuite extends SparkFunSuite
val gpuDiscovery = createTempScriptWithExpectedOutput(dir, "gpuDiscoveryScript",
"""{"name": "fpga", "addresses": ["0", "1"]}""")
val request =
- ResourceRequest(
+ new ResourceRequest(
DRIVER_GPU_ID,
2,
- Some(gpuDiscovery),
- None)
+ Optional.of(gpuDiscovery),
+ Optional.empty[String])
val error = intercept[SparkException] {
- discoverResource(request)
+ discoverResource(conf, request)
}.getMessage()
assert(error.contains(s"Error running the resource discovery script $gpuDiscovery: " +
@@ -255,6 +257,28 @@ class ResourceUtilsSuite extends SparkFunSuite
}
}
+ test("Resource discoverer with invalid class") {
+ val conf = new SparkConf()
+ .set(RESOURCES_DISCOVERY_PLUGIN, Seq("someinvalidclass"))
+ assume(!(Utils.isWindows))
+ withTempDir { dir =>
+ val gpuDiscovery = createTempScriptWithExpectedOutput(dir, "gpuDiscoveryScript",
+ """{"name": "fpga", "addresses": ["0", "1"]}""")
+ val request =
+ new ResourceRequest(
+ DRIVER_GPU_ID,
+ 2,
+ Optional.of(gpuDiscovery),
+ Optional.empty[String])
+
+ val error = intercept[ClassNotFoundException] {
+ discoverResource(conf, request)
+ }.getMessage()
+
+ assert(error.contains(s"someinvalidclass"))
+ }
+ }
+
test("Resource discoverer script returns invalid format") {
val conf = new SparkConf
assume(!(Utils.isWindows))
@@ -263,14 +287,14 @@ class ResourceUtilsSuite extends SparkFunSuite
"""{"addresses": ["0", "1"]}""")
val request =
- ResourceRequest(
+ new ResourceRequest(
EXECUTOR_GPU_ID,
2,
- Some(gpuDiscovery),
- None)
+ Optional.of(gpuDiscovery),
+ Optional.empty[String])
val error = intercept[SparkException] {
- discoverResource(request)
+ discoverResource(conf, request)
}.getMessage()
assert(error.contains("Error parsing JSON into ResourceInformation"))
@@ -283,14 +307,14 @@ class ResourceUtilsSuite extends SparkFunSuite
val file1 = new File(dir, "bogusfilepath")
try {
val request =
- ResourceRequest(
+ new ResourceRequest(
EXECUTOR_GPU_ID,
2,
- Some(file1.getPath()),
- None)
+ Optional.of(file1.getPath()),
+ Optional.empty[String])
val error = intercept[SparkException] {
- discoverResource(request)
+ discoverResource(conf, request)
}.getMessage()
assert(error.contains("doesn't exist"))
@@ -301,10 +325,11 @@ class ResourceUtilsSuite extends SparkFunSuite
}
test("gpu's specified but not a discovery script") {
- val request = ResourceRequest(EXECUTOR_GPU_ID, 2, None, None)
+ val request = new ResourceRequest(EXECUTOR_GPU_ID, 2, Optional.empty[String],
+ Optional.empty[String])
val error = intercept[SparkException] {
- discoverResource(request)
+ discoverResource(new SparkConf(), request)
}.getMessage()
assert(error.contains("User is expecting to use resource: gpu, but " +
diff --git a/core/src/test/scala/org/apache/spark/resource/TestResourceIDs.scala b/core/src/test/scala/org/apache/spark/resource/TestResourceIDs.scala
index c4509e9310..60246f5fad 100644
--- a/core/src/test/scala/org/apache/spark/resource/TestResourceIDs.scala
+++ b/core/src/test/scala/org/apache/spark/resource/TestResourceIDs.scala
@@ -22,14 +22,14 @@ import org.apache.spark.internal.config.Worker.SPARK_WORKER_PREFIX
import org.apache.spark.resource.ResourceUtils.{FPGA, GPU}
object TestResourceIDs {
- val DRIVER_GPU_ID = ResourceID(SPARK_DRIVER_PREFIX, GPU)
- val EXECUTOR_GPU_ID = ResourceID(SPARK_EXECUTOR_PREFIX, GPU)
- val TASK_GPU_ID = ResourceID(SPARK_TASK_PREFIX, GPU)
- val WORKER_GPU_ID = ResourceID(SPARK_WORKER_PREFIX, GPU)
+ val DRIVER_GPU_ID = new ResourceID(SPARK_DRIVER_PREFIX, GPU)
+ val EXECUTOR_GPU_ID = new ResourceID(SPARK_EXECUTOR_PREFIX, GPU)
+ val TASK_GPU_ID = new ResourceID(SPARK_TASK_PREFIX, GPU)
+ val WORKER_GPU_ID = new ResourceID(SPARK_WORKER_PREFIX, GPU)
- val DRIVER_FPGA_ID = ResourceID(SPARK_DRIVER_PREFIX, FPGA)
- val EXECUTOR_FPGA_ID = ResourceID(SPARK_EXECUTOR_PREFIX, FPGA)
- val TASK_FPGA_ID = ResourceID(SPARK_TASK_PREFIX, FPGA)
- val WORKER_FPGA_ID = ResourceID(SPARK_WORKER_PREFIX, FPGA)
+ val DRIVER_FPGA_ID = new ResourceID(SPARK_DRIVER_PREFIX, FPGA)
+ val EXECUTOR_FPGA_ID = new ResourceID(SPARK_EXECUTOR_PREFIX, FPGA)
+ val TASK_FPGA_ID = new ResourceID(SPARK_TASK_PREFIX, FPGA)
+ val WORKER_FPGA_ID = new ResourceID(SPARK_WORKER_PREFIX, FPGA)
}
diff --git a/docs/configuration.md b/docs/configuration.md
index 8164ed491d..2febfe9744 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -243,6 +243,18 @@ of the most common options to set are:
this config would be set to nvidia.com or amd.com)
+<tr>
+  <td><code>spark.resources.discovery.plugin</code></td>
+  <td>org.apache.spark.resource.ResourceDiscoveryScriptPlugin</td>
+  <td>
+    Comma-separated list of class names implementing
+    <code>org.apache.spark.api.resource.ResourceDiscoveryPlugin</code> to load into the application.
+    This is for advanced users to replace the resource discovery class with a
+    custom implementation. Spark will try each class specified until one of them
+    returns the resource information for that resource. It tries the discovery
+    script last if none of the plugins return information for that resource.
+  </td>
+</tr>
  <td><code>spark.executor.memory</code></td>
  <td>1g</td>
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
index b1b7751b01..e234b1780a 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
@@ -228,8 +228,12 @@ private[spark] object KubernetesUtils extends Logging {
sparkConf: SparkConf): Map[String, Quantity] = {
val requests = ResourceUtils.parseAllResourceRequests(sparkConf, componentName)
requests.map { request =>
- val vendorDomain = request.vendor.getOrElse(throw new SparkException("Resource: " +
- s"${request.id.resourceName} was requested, but vendor was not specified."))
+ val vendorDomain = if (request.vendor.isPresent()) {
+ request.vendor.get()
+ } else {
+ throw new SparkException(s"Resource: ${request.id.resourceName} was requested, " +
+ "but vendor was not specified.")
+ }
val quantity = new QuantityBuilder(false)
.withAmount(request.amount.toString)
.build()
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
index 710f28a859..ce66afd944 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
@@ -47,7 +47,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite {
}
test("Check the pod respects all configurations from the user.") {
- val resourceID = ResourceID(SPARK_DRIVER_PREFIX, GPU)
+ val resourceID = new ResourceID(SPARK_DRIVER_PREFIX, GPU)
val resources =
Map(("nvidia.com/gpu" -> TestResourceInformation(resourceID, "2", "nvidia.com")))
val sparkConf = new SparkConf()
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
index 51067bd889..f375b1fe6a 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
@@ -115,8 +115,8 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
}
test("basic executor pod with resources") {
- val fpgaResourceID = ResourceID(SPARK_EXECUTOR_PREFIX, FPGA)
- val gpuExecutorResourceID = ResourceID(SPARK_EXECUTOR_PREFIX, GPU)
+ val fpgaResourceID = new ResourceID(SPARK_EXECUTOR_PREFIX, FPGA)
+ val gpuExecutorResourceID = new ResourceID(SPARK_EXECUTOR_PREFIX, GPU)
val gpuResources =
Map(("nvidia.com/gpu" -> TestResourceInformation(gpuExecutorResourceID, "2", "nvidia.com")),
("foo.com/fpga" -> TestResourceInformation(fpgaResourceID, "1", "foo.com")))
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala
index f524962141..ae316b02ee 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala
@@ -40,8 +40,10 @@ import org.apache.spark.util.{CausedBy, Utils}
private object ResourceRequestHelper extends Logging {
private val AMOUNT_AND_UNIT_REGEX = "([0-9]+)([A-Za-z]*)".r
private val RESOURCE_INFO_CLASS = "org.apache.hadoop.yarn.api.records.ResourceInformation"
+ private val RESOURCE_NOT_FOUND = "org.apache.hadoop.yarn.exceptions.ResourceNotFoundException"
val YARN_GPU_RESOURCE_CONFIG = "yarn.io/gpu"
val YARN_FPGA_RESOURCE_CONFIG = "yarn.io/fpga"
+ @volatile private var numResourceErrors: Int = 0
private[yarn] def getYarnResourcesAndAmounts(
sparkConf: SparkConf,
@@ -76,7 +78,7 @@ private object ResourceRequestHelper extends Logging {
): Map[String, String] = {
Map(GPU -> YARN_GPU_RESOURCE_CONFIG, FPGA -> YARN_FPGA_RESOURCE_CONFIG).map {
case (rName, yarnName) =>
- (yarnName -> sparkConf.get(ResourceID(confPrefix, rName).amountConf, "0"))
+ (yarnName -> sparkConf.get(new ResourceID(confPrefix, rName).amountConf, "0"))
}.filter { case (_, count) => count.toLong > 0 }
}
@@ -108,13 +110,13 @@ private object ResourceRequestHelper extends Logging {
(AM_CORES.key, YARN_AM_RESOURCE_TYPES_PREFIX + "cpu-vcores"),
(DRIVER_CORES.key, YARN_DRIVER_RESOURCE_TYPES_PREFIX + "cpu-vcores"),
(EXECUTOR_CORES.key, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "cpu-vcores"),
- (ResourceID(SPARK_EXECUTOR_PREFIX, "fpga").amountConf,
+ (new ResourceID(SPARK_EXECUTOR_PREFIX, "fpga").amountConf,
s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${YARN_FPGA_RESOURCE_CONFIG}"),
- (ResourceID(SPARK_DRIVER_PREFIX, "fpga").amountConf,
+ (new ResourceID(SPARK_DRIVER_PREFIX, "fpga").amountConf,
s"${YARN_DRIVER_RESOURCE_TYPES_PREFIX}${YARN_FPGA_RESOURCE_CONFIG}"),
- (ResourceID(SPARK_EXECUTOR_PREFIX, "gpu").amountConf,
+ (new ResourceID(SPARK_EXECUTOR_PREFIX, "gpu").amountConf,
s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${YARN_GPU_RESOURCE_CONFIG}"),
- (ResourceID(SPARK_DRIVER_PREFIX, "gpu").amountConf,
+ (new ResourceID(SPARK_DRIVER_PREFIX, "gpu").amountConf,
s"${YARN_DRIVER_RESOURCE_TYPES_PREFIX}${YARN_GPU_RESOURCE_CONFIG}"))
val errorMessage = new mutable.StringBuilder()
@@ -185,7 +187,24 @@ private object ResourceRequestHelper extends Logging {
s"does not match pattern $AMOUNT_AND_UNIT_REGEX.")
case CausedBy(e: IllegalArgumentException) =>
throw new IllegalArgumentException(s"Invalid request for $name: ${e.getMessage}")
- case e: InvocationTargetException if e.getCause != null => throw e.getCause
+ case e: InvocationTargetException =>
+ if (e.getCause != null) {
+ if (Try(Utils.classForName(RESOURCE_NOT_FOUND)).isSuccess) {
+ if (e.getCause().getClass().getName().equals(RESOURCE_NOT_FOUND)) {
+ // warn a couple times and then stop so we don't spam the logs
+ if (numResourceErrors < 2) {
+ logWarning(s"YARN doesn't know about resource $name, your resource discovery " +
+ s"has to handle properly discovering and isolating the resource! Error: " +
+ s"${e.getCause().getMessage}")
+ numResourceErrors += 1
+ }
+ } else {
+ throw e.getCause
+ }
+ } else {
+ throw e.getCause
+ }
+ }
}
}
}
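
For orientation, the mapping above turns Spark resource amount confs into YARN resource types, dropping zero amounts. A standalone sketch of the same translation (names simplified; this is not the actual helper):

    import org.apache.spark.SparkConf

    // Sketch of the Spark-conf -> YARN-resource translation shown above.
    def sparkToYarnResources(confPrefix: String, sparkConf: SparkConf): Map[String, String] = {
      Map("gpu" -> "yarn.io/gpu", "fpga" -> "yarn.io/fpga").map { case (rName, yarnName) =>
        yarnName -> sparkConf.get(s"$confPrefix.resource.$rName.amount", "0")
      }.filter { case (_, count) => count.toLong > 0 }
    }

    val conf = new SparkConf().set("spark.executor.resource.gpu.amount", "2")
    assert(sparkToYarnResources("spark.executor", conf) == Map("yarn.io/gpu" -> "2"))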
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 7cce908cd5..b42c8b933d 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -403,7 +403,7 @@ class ClientSuite extends SparkFunSuite with Matchers {
conf.set(s"${YARN_DRIVER_RESOURCE_TYPES_PREFIX}${yarnName}.${AMOUNT}", "2")
}
resources.values.foreach { rName =>
- conf.set(ResourceID(SPARK_DRIVER_PREFIX, rName).amountConf, "3")
+ conf.set(new ResourceID(SPARK_DRIVER_PREFIX, rName).amountConf, "3")
}
val error = intercept[SparkException] {
@@ -426,7 +426,7 @@ class ClientSuite extends SparkFunSuite with Matchers {
conf.set(s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${yarnName}.${AMOUNT}", "2")
}
resources.values.foreach { rName =>
- conf.set(ResourceID(SPARK_EXECUTOR_PREFIX, rName).amountConf, "3")
+ conf.set(new ResourceID(SPARK_EXECUTOR_PREFIX, rName).amountConf, "3")
}
val error = intercept[SparkException] {
@@ -450,7 +450,7 @@ class ClientSuite extends SparkFunSuite with Matchers {
val conf = new SparkConf().set(SUBMIT_DEPLOY_MODE, "cluster")
resources.values.foreach { rName =>
- conf.set(ResourceID(SPARK_DRIVER_PREFIX, rName).amountConf, "3")
+ conf.set(new ResourceID(SPARK_DRIVER_PREFIX, rName).amountConf, "3")
}
// also just set yarn one that we don't convert
conf.set(s"${YARN_DRIVER_RESOURCE_TYPES_PREFIX}${yarnMadeupResource}.${AMOUNT}", "5")