[SPARK-20547][REPL] Throw RemoteClassLoaderError for transient errors in ExecutorClassLoader

## What changes were proposed in this pull request?

`ExecutorClassLoader`'s `findClass` may fail to fetch a class due to transient exceptions. For example, if a task is interrupted while `ExecutorClassLoader` is fetching a class, you may see an `InterruptedException` or `IOException` wrapped in a `ClassNotFoundException`, even though the class can be loaded. The result of `findClass` is then cached by the JVM, and when the same class is later loaded in the same executor, the JVM just throws `NoClassDefFoundError` even though the class can be loaded.

I found that the JVM only caches `LinkageError` and `ClassNotFoundException`. Hence, in this PR, I changed `ExecutorClassLoader` to throw `RemoteClassLoaderError` if we cannot get a response from the driver.

## How was this patch tested?

New unit tests.

Closes #24683 from zsxwing/SPARK-20547-fix.

Authored-by: Shixiong Zhu <zsxwing@gmail.com>
Signed-off-by: Shixiong Zhu <zsxwing@gmail.com>
This commit is contained in:
Shixiong Zhu 2019-05-28 12:56:14 -07:00
parent 4e61de4380
commit 04f142db9c
No known key found for this signature in database
GPG key ID: 34400CF75FADFD94
5 changed files with 214 additions and 11 deletions

View file

@ -140,6 +140,8 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
streamManager.streamSent(req.streamId);
});
} else {
// org.apache.spark.repl.ExecutorClassLoader.STREAM_NOT_FOUND_REGEX should also be updated
// when the following error message is changed.
respond(new StreamFailure(req.streamId, String.format(
"Stream '%s' was not found.", req.streamId)));
}

View file

@ -21,6 +21,8 @@ import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream,
import java.net.{URI, URL, URLEncoder}
import java.nio.channels.Channels
import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.xbean.asm7._
import org.apache.xbean.asm7.Opcodes._
@ -106,7 +108,17 @@ class ExecutorClassLoader(
parentLoader.loadClass(name)
} catch {
case e: ClassNotFoundException =>
val classOption = findClassLocally(name)
val classOption = try {
findClassLocally(name)
} catch {
case e: RemoteClassLoaderError =>
throw e
case NonFatal(e) =>
// Wrap the error to include the class name
// scalastyle:off throwerror
throw new RemoteClassLoaderError(name, e)
// scalastyle:on throwerror
}
classOption match {
case None => throw new ClassNotFoundException(name, e)
case Some(a) => a
@ -115,14 +127,15 @@ class ExecutorClassLoader(
}
}
// Matches the failure message sent when a requested stream does not exist; the text must stay in
// sync with org.apache.spark.network.server.TransportRequestHandler.processStreamRequest.
// The trailing dot is escaped so that only the literal message matches, and no string
// interpolation is needed because the pattern contains no variables.
private val STREAM_NOT_FOUND_REGEX = """Stream '.*' was not found\.""".r.pattern
private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = {
val channel = env.rpcEnv.openChannel(s"$classUri/$path")
val channel = env.rpcEnv.openChannel(s"$classUri/${urlEncode(path)}")
new FilterInputStream(Channels.newInputStream(channel)) {
override def read(): Int = toClassNotFound(super.read())
override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b))
override def read(b: Array[Byte], offset: Int, len: Int) =
toClassNotFound(super.read(b, offset, len))
@ -130,8 +143,15 @@ class ExecutorClassLoader(
try {
fn
} catch {
case e: Exception =>
case e: RuntimeException if e.getMessage != null
&& STREAM_NOT_FOUND_REGEX.matcher(e.getMessage).matches() =>
// Convert a stream not found error to ClassNotFoundException.
// Driver sends this explicit acknowledgment to tell us that the class was missing.
throw new ClassNotFoundException(path, e)
case NonFatal(e) =>
// scalastyle:off throwerror
throw new RemoteClassLoaderError(path, e)
// scalastyle:on throwerror
}
}
}
@ -163,7 +183,12 @@ class ExecutorClassLoader(
case e: Exception =>
// Something bad happened while checking if the class exists
logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
if (userClassPathFirst) {
// Allow to try to load from "parentLoader"
None
} else {
throw e
}
} finally {
if (inputStream != null) {
try {
@ -237,3 +262,11 @@ extends ClassVisitor(ASM7, cv) {
}
}
}
/**
 * Raised when fetching a class failed because of an exception, so it is unknown whether the class
 * actually exists. It deliberately extends plain [[Error]] rather than [[LinkageError]] or
 * [[ClassNotFoundException]], so that the JVM tries to load the class again on a later request.
 *
 * @param className name of the class that could not be loaded; used as the error message
 * @param cause the underlying failure that interrupted the load
 */
private[repl] class RemoteClassLoaderError(
    className: String,
    cause: Throwable)
  extends Error(className, cause)

View file

@ -17,9 +17,10 @@
package org.apache.spark.repl
import java.io.File
import java.io.{File, IOException}
import java.lang.reflect.InvocationTargetException
import java.net.{URI, URL, URLClassLoader}
import java.nio.channels.FileChannel
import java.nio.channels.{FileChannel, ReadableByteChannel}
import java.nio.charset.StandardCharsets
import java.nio.file.{Paths, StandardOpenOption}
import java.util
@ -30,13 +31,15 @@ import scala.io.Source
import scala.language.implicitConversions
import com.google.common.io.Files
import org.mockito.ArgumentMatchers.anyString
import org.mockito.ArgumentMatchers.{any, anyString}
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfterAll
import org.scalatest.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.TestUtils.JavaSourceFromString
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.util.Utils
@ -193,7 +196,14 @@ class ExecutorClassLoaderSuite
when(rpcEnv.openChannel(anyString())).thenAnswer((invocation: InvocationOnMock) => {
val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/"))
if (path.toFile.exists()) {
FileChannel.open(path, StandardOpenOption.READ)
} else {
val channel = mock[ReadableByteChannel]
when(channel.read(any()))
.thenThrow(new RuntimeException(s"Stream '${uri.getPath}' was not found."))
channel
}
})
val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234",
@ -218,4 +228,131 @@ class ExecutorClassLoaderSuite
}
}
test("nonexistent class and transient errors should cause different errors") {
  // Verifies that a class missing on a reachable driver yields ClassNotFoundException, while a
  // load attempted after the SparkContext is stopped yields RemoteClassLoaderError.
  val sparkConf = new SparkConf()
    .setMaster("local")
    .setAppName("executor-class-loader-test")
    .set("spark.network.timeout", "11s")
    .set("spark.repl.class.outputDir", tempDir1.getAbsolutePath)
  val sc = new SparkContext(sparkConf)
  try {
    val classUri = sc.conf.get("spark.repl.class.uri")
    // A separate client-mode RpcEnv plays the role of the executor side.
    val executorRpcEnv = RpcEnv.create(
      SparkEnv.executorSystemName, "localhost", "localhost", 0, sc.conf,
      new SecurityManager(sparkConf), 0, clientMode = true)
    try {
      val mockedEnv = mock[SparkEnv]
      when(mockedEnv.rpcEnv).thenReturn(executorRpcEnv)
      val loader = new ExecutorClassLoader(
        sparkConf, mockedEnv, classUri, getClass().getClassLoader(), false)

      // While the SparkContext is running, an unknown class is reported as not found.
      intercept[java.lang.ClassNotFoundException] {
        loader.loadClass("NonexistentClass")
      }

      // Stopping the SparkContext simulates a transient failure seen by the executor.
      sc.stop()

      val transientError = intercept[RemoteClassLoaderError] {
        loader.loadClass("ThisIsAClassName")
      }
      assert(transientError.getMessage.contains("ThisIsAClassName"))
      // The error must be neither a LinkageError nor a ClassNotFoundException; otherwise the JVM
      // caches the failure and never retries loading the class.
      assert(!(transientError.isInstanceOf[LinkageError] ||
        transientError.isInstanceOf[ClassNotFoundException]))
    } finally {
      executorRpcEnv.shutdown()
      executorRpcEnv.awaitTermination()
    }
  } finally {
    sc.stop()
  }
}
// Regression test for SPARK-20547: the first attempt to fetch TestClassA fails with an
// IOException, which must surface as RemoteClassLoaderError (not ClassNotFoundException), and a
// second attempt through the same class loader must then succeed.
test("SPARK-20547 ExecutorClassLoader should not throw ClassNotFoundException without " +
"acknowledgment from driver") {
val tempDir = Utils.createTempDir()
try {
// Create two classes, "TestClassB" calls "TestClassA", so when calling "TestClassB.foo", JVM
// will try to load "TestClassA".
val sourceCodeOfClassA =
"""public class TestClassA implements java.io.Serializable {
| @Override public String toString() { return "TestClassA"; }
|}""".stripMargin
val sourceFileA = new JavaSourceFromString("TestClassA", sourceCodeOfClassA)
TestUtils.createCompiledClass(
sourceFileA.name, tempDir, sourceFileA, Seq(tempDir.toURI.toURL))
val sourceCodeOfClassB =
"""public class TestClassB implements java.io.Serializable {
| public String foo() { return new TestClassA().toString(); }
| @Override public String toString() { return "TestClassB"; }
|}""".stripMargin
val sourceFileB = new JavaSourceFromString("TestClassB", sourceCodeOfClassB)
TestUtils.createCompiledClass(
sourceFileB.name, tempDir, sourceFileB, Seq(tempDir.toURI.toURL))
// Mock the RpcEnv so that class files are served from tempDir instead of a real driver.
val env = mock[SparkEnv]
val rpcEnv = mock[RpcEnv]
when(env.rpcEnv).thenReturn(rpcEnv)
when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() {
// Counts how many times the failing channel has been handed out for TestClassA.class.
private var count = 0
override def answer(invocation: InvocationOnMock): ReadableByteChannel = {
val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
val classFileName = uri.getPath().stripPrefix("/")
if (count == 0 && classFileName == "TestClassA.class") {
count += 1
// Let the first attempt to load TestClassA fail with an IOException
val channel = mock[ReadableByteChannel]
when(channel.read(any())).thenThrow(new IOException("broken pipe"))
channel
}
else {
// Every other request is answered with the compiled class file from tempDir.
val path = Paths.get(tempDir.getAbsolutePath(), classFileName)
FileChannel.open(path, StandardOpenOption.READ)
}
}
})
val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234",
getClass().getClassLoader(), false)
// Loads TestClassB through the class loader under test and calls foo() reflectively;
// foo() instantiates TestClassA, which triggers loading of TestClassA.
def callClassBFoo(): String = {
// scalastyle:off classforname
val classB = Class.forName("TestClassB", true, classLoader)
// scalastyle:on classforname
val instanceOfTestClassB = classB.newInstance()
assert(instanceOfTestClassB.toString === "TestClassB")
classB.getMethod("foo").invoke(instanceOfTestClassB).asInstanceOf[String]
}
// Reflection will wrap the exception with InvocationTargetException
val e = intercept[InvocationTargetException] {
callClassBFoo()
}
// "TestClassA" cannot be loaded because of IOException
assert(e.getCause.isInstanceOf[RemoteClassLoaderError])
assert(e.getCause.getCause.isInstanceOf[IOException])
assert(e.getCause.getMessage.contains("TestClassA"))
// The earlier IOException must not be cached: a second attempt should load TestClassA
assert(callClassBFoo() === "TestClassA")
} finally {
Utils.deleteRecursively(tempDir)
}
}
}

View file

@ -22,12 +22,27 @@ import java.io._
import scala.tools.nsc.interpreter.SimpleReader
import org.apache.log4j.{Level, LogManager}
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
class ReplSuite extends SparkFunSuite {
class ReplSuite extends SparkFunSuite with BeforeAndAfterAll {
// Context class loader of the test thread as it was before any test ran; stays null until
// beforeAll() has executed.
private var originalClassLoader: ClassLoader = null

/** Records the current thread's context class loader so that afterAll() can restore it. */
override def beforeAll(): Unit = {
  // Run the parent suite's setup first so that stacked lifecycle logic is not skipped.
  super.beforeAll()
  originalClassLoader = Thread.currentThread().getContextClassLoader
}

/** Restores the context class loader captured in beforeAll(), then runs the parent teardown. */
override def afterAll(): Unit = {
  try {
    if (originalClassLoader != null) {
      // Reset the class loader to not affect other suites. REPL will set its own class loader but
      // doesn't reset it.
      Thread.currentThread().setContextClassLoader(originalClassLoader)
    }
  } finally {
    // Always delegate to the parent teardown, even if restoring the class loader fails.
    super.afterAll()
  }
}
def runInterpreter(master: String, input: String): String = {
val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"

View file

@ -390,4 +390,20 @@ class SingletonReplSuite extends SparkFunSuite {
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}
// Runs a REPL script that defines a case class and creates an ExpressionEncoder for it inside a
// map task, then checks that the job returns the expected value without reporting any error or
// exception in the interpreter output.
test("create encoder in executors") {
val output = runInterpreter(
"""
|case class Foo(s: String)
|
|import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
|val r =
| sc.parallelize(1 to 1).map { i => ExpressionEncoder[Foo](); Foo("bar") }.collect.head
""".stripMargin)
// The REPL echoes the value bound to `r`; its presence shows that the job completed.
assertContains("r: Foo = Foo(bar)", output)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}
}