diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 3e089b4cae..0792b58304 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -140,6 +140,8 @@ public class TransportRequestHandler extends MessageHandler { streamManager.streamSent(req.streamId); }); } else { + // org.apache.spark.repl.ExecutorClassLoader.STREAM_NOT_FOUND_REGEX should also be updated + // when the following error message is changed. respond(new StreamFailure(req.streamId, String.format( "Stream '%s' was not found.", req.streamId))); } diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 177bce2f00..0cfd96193d 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -21,6 +21,8 @@ import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels +import scala.util.control.NonFatal + import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.xbean.asm7._ import org.apache.xbean.asm7.Opcodes._ @@ -106,7 +108,17 @@ class ExecutorClassLoader( parentLoader.loadClass(name) } catch { case e: ClassNotFoundException => - val classOption = findClassLocally(name) + val classOption = try { + findClassLocally(name) + } catch { + case e: RemoteClassLoaderError => + throw e + case NonFatal(e) => + // Wrap the error to include the class name + // scalastyle:off throwerror + throw new RemoteClassLoaderError(name, e) + // scalastyle:on throwerror + } classOption match { case None => throw new ClassNotFoundException(name, e) case Some(a) => a @@ -115,14 +127,15 @@ class ExecutorClassLoader( } } + // See org.apache.spark.network.server.TransportRequestHandler.processStreamRequest. + private val STREAM_NOT_FOUND_REGEX = s"Stream '.*' was not found.".r.pattern + private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = { - val channel = env.rpcEnv.openChannel(s"$classUri/$path") + val channel = env.rpcEnv.openChannel(s"$classUri/${urlEncode(path)}") new FilterInputStream(Channels.newInputStream(channel)) { override def read(): Int = toClassNotFound(super.read()) - override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b)) - override def read(b: Array[Byte], offset: Int, len: Int) = toClassNotFound(super.read(b, offset, len)) @@ -130,8 +143,15 @@ class ExecutorClassLoader( try { fn } catch { - case e: Exception => + case e: RuntimeException if e.getMessage != null + && STREAM_NOT_FOUND_REGEX.matcher(e.getMessage).matches() => + // Convert a stream not found error to ClassNotFoundException. + // Driver sends this explicit acknowledgment to tell us that the class was missing. throw new ClassNotFoundException(path, e) + case NonFatal(e) => + // scalastyle:off throwerror + throw new RemoteClassLoaderError(path, e) + // scalastyle:on throwerror } } } @@ -163,7 +183,12 @@ class ExecutorClassLoader( case e: Exception => // Something bad happened while checking if the class exists logError(s"Failed to check existence of class $name on REPL class server at $uri", e) - None + if (userClassPathFirst) { + // Allow to try to load from "parentLoader" + None + } else { + throw e + } } finally { if (inputStream != null) { try { @@ -237,3 +262,11 @@ extends ClassVisitor(ASM7, cv) { } } } + +/** + * An error when we cannot load a class due to exceptions. We don't know if this class exists, so + * throw a special one that's neither [[LinkageError]] nor [[ClassNotFoundException]] to make JVM + * retry to load this class later. + */ +private[repl] class RemoteClassLoaderError(className: String, cause: Throwable) + extends Error(className, cause) diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 0276f2dd40..dceae13fd8 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.repl -import java.io.File +import java.io.{File, IOException} +import java.lang.reflect.InvocationTargetException import java.net.{URI, URL, URLClassLoader} -import java.nio.channels.FileChannel +import java.nio.channels.{FileChannel, ReadableByteChannel} import java.nio.charset.StandardCharsets import java.nio.file.{Paths, StandardOpenOption} import java.util @@ -30,13 +31,15 @@ import scala.io.Source import scala.language.implicitConversions import com.google.common.io.Files -import org.mockito.ArgumentMatchers.anyString +import org.mockito.ArgumentMatchers.{any, anyString} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterAll import org.scalatest.mockito.MockitoSugar import org.apache.spark._ +import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils @@ -193,7 +196,14 @@ class ExecutorClassLoaderSuite when(rpcEnv.openChannel(anyString())).thenAnswer((invocation: InvocationOnMock) => { val uri = new URI(invocation.getArguments()(0).asInstanceOf[String]) val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/")) - FileChannel.open(path, StandardOpenOption.READ) + if (path.toFile.exists()) { + FileChannel.open(path, StandardOpenOption.READ) + } else { + val channel = mock[ReadableByteChannel] + when(channel.read(any())) + .thenThrow(new RuntimeException(s"Stream '${uri.getPath}' was not found.")) + channel + } }) val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234", @@ -218,4 +228,131 @@ class ExecutorClassLoaderSuite } } + test("nonexistent class and transient errors should cause different errors") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("executor-class-loader-test") + .set("spark.network.timeout", "11s") + .set("spark.repl.class.outputDir", tempDir1.getAbsolutePath) + val sc = new SparkContext(conf) + try { + val replClassUri = sc.conf.get("spark.repl.class.uri") + + // Create an RpcEnv for executor + val rpcEnv = RpcEnv.create( + SparkEnv.executorSystemName, + "localhost", + "localhost", + 0, + sc.conf, + new SecurityManager(conf), 0, clientMode = true) + + try { + val env = mock[SparkEnv] + when(env.rpcEnv).thenReturn(rpcEnv) + + val classLoader = new ExecutorClassLoader( + conf, + env, + replClassUri, + getClass().getClassLoader(), + false) + + // Test loading a nonexistent class + intercept[java.lang.ClassNotFoundException] { + classLoader.loadClass("NonexistentClass") + } + + // Stop SparkContext to simulate transient errors in executors + sc.stop() + + val e = intercept[RemoteClassLoaderError] { + classLoader.loadClass("ThisIsAClassName") + } + assert(e.getMessage.contains("ThisIsAClassName")) + // RemoteClassLoaderError must not be LinkageError nor ClassNotFoundException. Otherwise, + // JVM will cache it and doesn't retry to load a class. + assert(!e.isInstanceOf[LinkageError] && !e.isInstanceOf[ClassNotFoundException]) + } finally { + rpcEnv.shutdown() + rpcEnv.awaitTermination() + } + } finally { + sc.stop() + } + } + + test("SPARK-20547 ExecutorClassLoader should not throw ClassNotFoundException without " + + "acknowledgment from driver") { + val tempDir = Utils.createTempDir() + try { + // Create two classes, "TestClassB" calls "TestClassA", so when calling "TestClassB.foo", JVM + // will try to load "TestClassA". + val sourceCodeOfClassA = + """public class TestClassA implements java.io.Serializable { + | @Override public String toString() { return "TestClassA"; } + |}""".stripMargin + val sourceFileA = new JavaSourceFromString("TestClassA", sourceCodeOfClassA) + TestUtils.createCompiledClass( + sourceFileA.name, tempDir, sourceFileA, Seq(tempDir.toURI.toURL)) + + val sourceCodeOfClassB = + """public class TestClassB implements java.io.Serializable { + | public String foo() { return new TestClassA().toString(); } + | @Override public String toString() { return "TestClassB"; } + |}""".stripMargin + val sourceFileB = new JavaSourceFromString("TestClassB", sourceCodeOfClassB) + TestUtils.createCompiledClass( + sourceFileB.name, tempDir, sourceFileB, Seq(tempDir.toURI.toURL)) + + val env = mock[SparkEnv] + val rpcEnv = mock[RpcEnv] + when(env.rpcEnv).thenReturn(rpcEnv) + when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() { + private var count = 0 + + override def answer(invocation: InvocationOnMock): ReadableByteChannel = { + val uri = new URI(invocation.getArguments()(0).asInstanceOf[String]) + val classFileName = uri.getPath().stripPrefix("/") + if (count == 0 && classFileName == "TestClassA.class") { + count += 1 + // Let the first attempt to load TestClassA fail with an IOException + val channel = mock[ReadableByteChannel] + when(channel.read(any())).thenThrow(new IOException("broken pipe")) + channel + } + else { + val path = Paths.get(tempDir.getAbsolutePath(), classFileName) + FileChannel.open(path, StandardOpenOption.READ) + } + } + }) + + val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234", + getClass().getClassLoader(), false) + + def callClassBFoo(): String = { + // scalastyle:off classforname + val classB = Class.forName("TestClassB", true, classLoader) + // scalastyle:on classforname + val instanceOfTestClassB = classB.newInstance() + assert(instanceOfTestClassB.toString === "TestClassB") + classB.getMethod("foo").invoke(instanceOfTestClassB).asInstanceOf[String] + } + + // Reflection will wrap the exception with InvocationTargetException + val e = intercept[InvocationTargetException] { + callClassBFoo() + } + // "TestClassA" cannot be loaded because of IOException + assert(e.getCause.isInstanceOf[RemoteClassLoaderError]) + assert(e.getCause.getCause.isInstanceOf[IOException]) + assert(e.getCause.getMessage.contains("TestClassA")) + + // We should be able to re-load TestClassA for IOException + assert(callClassBFoo() === "TestClassA") + } finally { + Utils.deleteRecursively(tempDir) + } + } } diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 4849c7cdc7..38e3fc4f93 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,12 +22,27 @@ import java.io._ import scala.tools.nsc.interpreter.SimpleReader import org.apache.log4j.{Level, LogManager} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION -class ReplSuite extends SparkFunSuite { +class ReplSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var originalClassLoader: ClassLoader = null + + override def beforeAll(): Unit = { + originalClassLoader = Thread.currentThread().getContextClassLoader + } + + override def afterAll(): Unit = { + if (originalClassLoader != null) { + // Reset the class loader to not affect other suites. REPL will set its own class loader but + // doesn't reset it. + Thread.currentThread().setContextClassLoader(originalClassLoader) + } + } def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" diff --git a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala index 7e3d0d9244..777de967b6 100644 --- a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala @@ -390,4 +390,20 @@ class SingletonReplSuite extends SparkFunSuite { assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } + + test("create encoder in executors") { + val output = runInterpreter( + """ + |case class Foo(s: String) + | + |import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder + | + |val r = + | sc.parallelize(1 to 1).map { i => ExpressionEncoder[Foo](); Foo("bar") }.collect.head + """.stripMargin) + + assertContains("r: Foo = Foo(bar)", output) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } }