[SPARK-20547][REPL] Throw RemoteClassLoaderError for transient errors in ExecutorClassLoader
## What changes were proposed in this pull request?

`ExecutorClassLoader`'s `findClass` may fail to fetch a class due to transient exceptions. For example, when a task is interrupted while `ExecutorClassLoader` is fetching a class, you may see an `InterruptedException` or `IOException` wrapped in a `ClassNotFoundException`, even though the class could in fact be loaded. The JVM then caches the result of `findClass`, and when the same class is loaded later on the same executor it throws `NoClassDefFoundError` without retrying, even though the class is loadable. I found that the JVM only caches `LinkageError` and `ClassNotFoundException`. Hence this PR changes `ExecutorClassLoader` to throw `RemoteClassLoaderError` when it cannot get a response from the driver.

## How was this patch tested?

New unit tests.

Closes #24683 from zsxwing/SPARK-20547-fix.

Authored-by: Shixiong Zhu <zsxwing@gmail.com>
Signed-off-by: Shixiong Zhu <zsxwing@gmail.com>
Parent: 4e61de4380
Commit: 04f142db9c
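For context, here is a minimal sketch of the pattern this patch adopts — not code from the patch itself. Transient fetch failures are rethrown as an `Error` subtype that is neither `LinkageError` nor `ClassNotFoundException`, so (per the commit message's observation about what the JVM caches) the next attempt to resolve the same class reaches `findClass` again instead of hitting a cached `NoClassDefFoundError`. The `SketchLoader`, `RemoteLoadError`, and `fetchClassBytes` names are hypothetical stand-ins.

```scala
import scala.util.control.NonFatal

// Hypothetical stand-ins for ExecutorClassLoader and RemoteClassLoaderError;
// this only illustrates the throw-an-uncached-Error pattern.
class RemoteLoadError(name: String, cause: Throwable) extends Error(name, cause)

class SketchLoader(parent: ClassLoader) extends ClassLoader(parent) {
  // Simulated transient RPC fetch of class bytes from the driver.
  private def fetchClassBytes(name: String): Array[Byte] =
    throw new java.io.IOException("connection reset")

  override def findClass(name: String): Class[_] =
    try {
      val bytes = fetchClassBytes(name)
      defineClass(name, bytes, 0, bytes.length)
    } catch {
      // Rethrow as an Error that is neither LinkageError nor
      // ClassNotFoundException: the commit message notes those are the only
      // failures the JVM caches, so resolution of the same symbolic
      // reference will invoke findClass again later.
      case NonFatal(e) => throw new RemoteLoadError(name, e)
    }
}
```

The patch applies this idea in two places: around `findClassLocally` in `findClass`, and around each read from the RPC stream in `toClassNotFound`.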
#### TransportRequestHandler.java

```diff
@@ -140,6 +140,8 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
         streamManager.streamSent(req.streamId);
       });
     } else {
+      // org.apache.spark.repl.ExecutorClassLoader.STREAM_NOT_FOUND_REGEX should also be updated
+      // when the following error message is changed.
       respond(new StreamFailure(req.streamId, String.format(
         "Stream '%s' was not found.", req.streamId)));
     }
```
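The new comment pins the driver's error string to the executor-side pattern added later in this commit (`STREAM_NOT_FOUND_REGEX` in `ExecutorClassLoader.scala`). A quick illustrative check that the two stay in sync — the stream id below is a made-up example:

```scala
val STREAM_NOT_FOUND_REGEX = "Stream '.*' was not found.".r.pattern

// Matches the exact message format produced by TransportRequestHandler above...
assert(STREAM_NOT_FOUND_REGEX.matcher("Stream '/classes/Foo.class' was not found.").matches())
// ...but not arbitrary transient failures, which must stay retryable.
assert(!STREAM_NOT_FOUND_REGEX.matcher("Connection reset by peer").matches())
```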
#### ExecutorClassLoader.scala

```diff
@@ -21,6 +21,8 @@ import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream,
 import java.net.{URI, URL, URLEncoder}
 import java.nio.channels.Channels
 
+import scala.util.control.NonFatal
+
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.xbean.asm7._
 import org.apache.xbean.asm7.Opcodes._
```
```diff
@@ -106,7 +108,17 @@ class ExecutorClassLoader(
         parentLoader.loadClass(name)
       } catch {
         case e: ClassNotFoundException =>
-          val classOption = findClassLocally(name)
+          val classOption = try {
+            findClassLocally(name)
+          } catch {
+            case e: RemoteClassLoaderError =>
+              throw e
+            case NonFatal(e) =>
+              // Wrap the error to include the class name
+              // scalastyle:off throwerror
+              throw new RemoteClassLoaderError(name, e)
+              // scalastyle:on throwerror
+          }
           classOption match {
             case None => throw new ClassNotFoundException(name, e)
             case Some(a) => a
```
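Note that the ordering of the two new catch cases is deliberate: `RemoteClassLoaderError` extends `Error`, and `scala.util.control.NonFatal` matches plain `Error`s, so without the dedicated first case an error already wrapped by a nested load would be wrapped a second time. A minimal check of that `NonFatal` behavior:

```scala
import scala.util.control.NonFatal

// NonFatal treats only VirtualMachineError, ThreadDeath, InterruptedException,
// LinkageError and ControlThrowable as fatal, so a plain Error (and therefore
// RemoteClassLoaderError) is matched by the NonFatal extractor.
assert(NonFatal(new Error("matched by NonFatal")))
assert(!NonFatal(new LinkageError("fatal: not matched")))
assert(!NonFatal(new InterruptedException("fatal: not matched")))
```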
```diff
@@ -115,14 +127,15 @@ class ExecutorClassLoader(
     }
   }
 
+  // See org.apache.spark.network.server.TransportRequestHandler.processStreamRequest.
+  private val STREAM_NOT_FOUND_REGEX = s"Stream '.*' was not found.".r.pattern
+
   private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = {
-    val channel = env.rpcEnv.openChannel(s"$classUri/$path")
+    val channel = env.rpcEnv.openChannel(s"$classUri/${urlEncode(path)}")
     new FilterInputStream(Channels.newInputStream(channel)) {
 
       override def read(): Int = toClassNotFound(super.read())
 
-      override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b))
-
       override def read(b: Array[Byte], offset: Int, len: Int) =
         toClassNotFound(super.read(b, offset, len))
 
```
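The switch from `$path` to `${urlEncode(path)}` matters because REPL-generated class file names routinely contain characters such as `$`. Assuming `urlEncode` delegates to `java.net.URLEncoder` (consistent with the pre-existing `URLEncoder` import above), the effect looks like this — the path is a hypothetical REPL-style example:

```scala
import java.net.URLEncoder

// '$' and '/' are percent-encoded so the stream id survives URI parsing on
// its way to the driver.
println(URLEncoder.encode("$line3/$read$$iw.class", "UTF-8"))
// prints: %24line3%2F%24read%24%24iw.class
```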
```diff
@@ -130,8 +143,15 @@ class ExecutorClassLoader(
       try {
         fn
       } catch {
-        case e: Exception =>
+        case e: RuntimeException if e.getMessage != null
+          && STREAM_NOT_FOUND_REGEX.matcher(e.getMessage).matches() =>
+          // Convert a stream not found error to ClassNotFoundException.
+          // Driver sends this explicit acknowledgment to tell us that the class was missing.
           throw new ClassNotFoundException(path, e)
+        case NonFatal(e) =>
+          // scalastyle:off throwerror
+          throw new RemoteClassLoaderError(path, e)
+          // scalastyle:on throwerror
       }
     }
   }
```
```diff
@@ -163,7 +183,12 @@ class ExecutorClassLoader(
       case e: Exception =>
         // Something bad happened while checking if the class exists
         logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
-        None
+        if (userClassPathFirst) {
+          // Allow to try to load from "parentLoader"
+          None
+        } else {
+          throw e
+        }
     } finally {
       if (inputStream != null) {
         try {
```
```diff
@@ -237,3 +262,11 @@ extends ClassVisitor(ASM7, cv) {
     }
   }
 }
+
+/**
+ * An error when we cannot load a class due to exceptions. We don't know if this class exists, so
+ * throw a special one that's neither [[LinkageError]] nor [[ClassNotFoundException]] to make JVM
+ * retry to load this class later.
+ */
+private[repl] class RemoteClassLoaderError(className: String, cause: Throwable)
+  extends Error(className, cause)
```
#### ExecutorClassLoaderSuite.scala

```diff
@@ -17,9 +17,10 @@
 
 package org.apache.spark.repl
 
-import java.io.File
+import java.io.{File, IOException}
+import java.lang.reflect.InvocationTargetException
 import java.net.{URI, URL, URLClassLoader}
-import java.nio.channels.FileChannel
+import java.nio.channels.{FileChannel, ReadableByteChannel}
 import java.nio.charset.StandardCharsets
 import java.nio.file.{Paths, StandardOpenOption}
 import java.util
```
```diff
@@ -30,13 +31,15 @@ import scala.io.Source
 import scala.language.implicitConversions
 
 import com.google.common.io.Files
-import org.mockito.ArgumentMatchers.anyString
+import org.mockito.ArgumentMatchers.{any, anyString}
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.mockito.MockitoSugar
 
 import org.apache.spark._
+import org.apache.spark.TestUtils.JavaSourceFromString
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.util.Utils
```
```diff
@@ -193,7 +196,14 @@ class ExecutorClassLoaderSuite
     when(rpcEnv.openChannel(anyString())).thenAnswer((invocation: InvocationOnMock) => {
       val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
       val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/"))
-      FileChannel.open(path, StandardOpenOption.READ)
+      if (path.toFile.exists()) {
+        FileChannel.open(path, StandardOpenOption.READ)
+      } else {
+        val channel = mock[ReadableByteChannel]
+        when(channel.read(any()))
+          .thenThrow(new RuntimeException(s"Stream '${uri.getPath}' was not found."))
+        channel
+      }
     })
 
     val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234",
```
```diff
@@ -218,4 +228,131 @@ class ExecutorClassLoaderSuite
     }
   }
 
+  test("nonexistent class and transient errors should cause different errors") {
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("executor-class-loader-test")
+      .set("spark.network.timeout", "11s")
+      .set("spark.repl.class.outputDir", tempDir1.getAbsolutePath)
+    val sc = new SparkContext(conf)
+    try {
+      val replClassUri = sc.conf.get("spark.repl.class.uri")
+
+      // Create an RpcEnv for executor
+      val rpcEnv = RpcEnv.create(
+        SparkEnv.executorSystemName,
+        "localhost",
+        "localhost",
+        0,
+        sc.conf,
+        new SecurityManager(conf), 0, clientMode = true)
+
+      try {
+        val env = mock[SparkEnv]
+        when(env.rpcEnv).thenReturn(rpcEnv)
+
+        val classLoader = new ExecutorClassLoader(
+          conf,
+          env,
+          replClassUri,
+          getClass().getClassLoader(),
+          false)
+
+        // Test loading a nonexistent class
+        intercept[java.lang.ClassNotFoundException] {
+          classLoader.loadClass("NonexistentClass")
+        }
+
+        // Stop SparkContext to simulate transient errors in executors
+        sc.stop()
+
+        val e = intercept[RemoteClassLoaderError] {
+          classLoader.loadClass("ThisIsAClassName")
+        }
+        assert(e.getMessage.contains("ThisIsAClassName"))
+        // RemoteClassLoaderError must not be LinkageError nor ClassNotFoundException. Otherwise,
+        // JVM will cache it and doesn't retry to load a class.
+        assert(!e.isInstanceOf[LinkageError] && !e.isInstanceOf[ClassNotFoundException])
+      } finally {
+        rpcEnv.shutdown()
+        rpcEnv.awaitTermination()
+      }
+    } finally {
+      sc.stop()
+    }
+  }
+
+  test("SPARK-20547 ExecutorClassLoader should not throw ClassNotFoundException without " +
+    "acknowledgment from driver") {
+    val tempDir = Utils.createTempDir()
+    try {
+      // Create two classes, "TestClassB" calls "TestClassA", so when calling "TestClassB.foo", JVM
+      // will try to load "TestClassA".
+      val sourceCodeOfClassA =
+        """public class TestClassA implements java.io.Serializable {
+          |  @Override public String toString() { return "TestClassA"; }
+          |}""".stripMargin
+      val sourceFileA = new JavaSourceFromString("TestClassA", sourceCodeOfClassA)
+      TestUtils.createCompiledClass(
+        sourceFileA.name, tempDir, sourceFileA, Seq(tempDir.toURI.toURL))
+
+      val sourceCodeOfClassB =
+        """public class TestClassB implements java.io.Serializable {
+          |  public String foo() { return new TestClassA().toString(); }
+          |  @Override public String toString() { return "TestClassB"; }
+          |}""".stripMargin
+      val sourceFileB = new JavaSourceFromString("TestClassB", sourceCodeOfClassB)
+      TestUtils.createCompiledClass(
+        sourceFileB.name, tempDir, sourceFileB, Seq(tempDir.toURI.toURL))
+
+      val env = mock[SparkEnv]
+      val rpcEnv = mock[RpcEnv]
+      when(env.rpcEnv).thenReturn(rpcEnv)
+      when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() {
+        private var count = 0
+
+        override def answer(invocation: InvocationOnMock): ReadableByteChannel = {
+          val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
+          val classFileName = uri.getPath().stripPrefix("/")
+          if (count == 0 && classFileName == "TestClassA.class") {
+            count += 1
+            // Let the first attempt to load TestClassA fail with an IOException
+            val channel = mock[ReadableByteChannel]
+            when(channel.read(any())).thenThrow(new IOException("broken pipe"))
+            channel
+          }
+          else {
+            val path = Paths.get(tempDir.getAbsolutePath(), classFileName)
+            FileChannel.open(path, StandardOpenOption.READ)
+          }
+        }
+      })
+
+      val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234",
+        getClass().getClassLoader(), false)
+
+      def callClassBFoo(): String = {
+        // scalastyle:off classforname
+        val classB = Class.forName("TestClassB", true, classLoader)
+        // scalastyle:on classforname
+        val instanceOfTestClassB = classB.newInstance()
+        assert(instanceOfTestClassB.toString === "TestClassB")
+        classB.getMethod("foo").invoke(instanceOfTestClassB).asInstanceOf[String]
+      }
+
+      // Reflection will wrap the exception with InvocationTargetException
+      val e = intercept[InvocationTargetException] {
+        callClassBFoo()
+      }
+      // "TestClassA" cannot be loaded because of IOException
+      assert(e.getCause.isInstanceOf[RemoteClassLoaderError])
+      assert(e.getCause.getCause.isInstanceOf[IOException])
+      assert(e.getCause.getMessage.contains("TestClassA"))
+
+      // We should be able to re-load TestClassA for IOException
+      assert(callClassBFoo() === "TestClassA")
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
```
#### ReplSuite.scala

```diff
@@ -22,12 +22,27 @@ import java.io._
 import scala.tools.nsc.interpreter.SimpleReader
 
 import org.apache.log4j.{Level, LogManager}
+import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{SparkContext, SparkFunSuite}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 
-class ReplSuite extends SparkFunSuite {
+class ReplSuite extends SparkFunSuite with BeforeAndAfterAll {
+
+  private var originalClassLoader: ClassLoader = null
+
+  override def beforeAll(): Unit = {
+    originalClassLoader = Thread.currentThread().getContextClassLoader
+  }
+
+  override def afterAll(): Unit = {
+    if (originalClassLoader != null) {
+      // Reset the class loader to not affect other suites. REPL will set its own class loader but
+      // doesn't reset it.
+      Thread.currentThread().setContextClassLoader(originalClassLoader)
+    }
+  }
 
   def runInterpreter(master: String, input: String): String = {
     val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
```
#### SingletonReplSuite.scala

```diff
@@ -390,4 +390,20 @@ class SingletonReplSuite extends SparkFunSuite {
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
   }
+
+  test("create encoder in executors") {
+    val output = runInterpreter(
+      """
+        |case class Foo(s: String)
+        |
+        |import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+        |
+        |val r =
+        |  sc.parallelize(1 to 1).map { i => ExpressionEncoder[Foo](); Foo("bar") }.collect.head
+      """.stripMargin)
+
+    assertContains("r: Foo = Foo(bar)", output)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
 }
```