diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
index 3176502b9e..177bce2f00 100644
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -33,8 +33,11 @@ import org.apache.spark.util.ParentClassLoader
 /**
  * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load
  * classes defined by the interpreter when the REPL is used. Allows the user to specify if user
- * class path should be first. This class loader delegates getting/finding resources to parent
- * loader, which makes sense until REPL never provide resource dynamically.
+ * class path should be first.
+ * This class loader delegates getting/finding resources to parent loader, which makes sense because
+ * the REPL never produces resources dynamically. One exception is when getting a Class file as
+ * a resource stream, in which case we will try to fetch the Class file in the same way as loading
+ * the class, so that dynamically generated Classes from the REPL can be picked up.
  *
  * Note: [[ClassLoader]] will preferentially load class from parent. Only when parent is null or
  * the load failed, that it will call the overridden `findClass` function. To avoid the potential
@@ -71,6 +74,30 @@ class ExecutorClassLoader(
     parentLoader.getResources(name)
   }
 
+  override def getResourceAsStream(name: String): InputStream = {
+    if (userClassPathFirst) {
+      val res = getClassResourceAsStreamLocally(name)
+      if (res != null) res else parentLoader.getResourceAsStream(name)
+    } else {
+      val res = parentLoader.getResourceAsStream(name)
+      if (res != null) res else getClassResourceAsStreamLocally(name)
+    }
+  }
+
+  private def getClassResourceAsStreamLocally(name: String): InputStream = {
+    // Class files can be dynamically generated from the REPL. Allow this class loader to
+    // load such files for purposes other than loading the class.
+    try {
+      if (name.endsWith(".class")) fetchFn(name) else null
+    } catch {
+      // The helper functions referenced by fetchFn throw CNFE to indicate failure to fetch
+      // the class. It matches what IOException was supposed to be used for, and
+      // ClassLoader.getResourceAsStream() catches IOException and returns null in that case.
+      // So we follow that model and handle CNFE here.
+      case _: ClassNotFoundException => null
+    }
+  }
+
   override def findClass(name: String): Class[_] = {
     if (userClassPathFirst) {
       findClassLocally(name).getOrElse(parentLoader.loadClass(name))
diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
index e9ed01ff22..4752495e8e 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -208,6 +208,17 @@
     intercept[java.lang.ClassNotFoundException] {
       classLoader.loadClass("ReplFakeClassDoesNotExist").getConstructor().newInstance()
     }
+
+    // classLoader.getResourceAsStream() should also be able to fetch the Class file
+    val fakeClassInputStream = classLoader.getResourceAsStream("ReplFakeClass2.class")
+    try {
+      val magic = new Array[Byte](4)
+      fakeClassInputStream.read(magic)
+      // first 4 bytes should match the magic number of Class file
+      assert(magic === Array[Byte](0xCA.toByte, 0xFE.toByte, 0xBA.toByte, 0xBE.toByte))
+    } finally {
+      if (fakeClassInputStream != null) fakeClassInputStream.close()
+    }
   }
 
 }
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 4f3df72917..a46cb6b3f4 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -260,4 +260,36 @@ class ReplSuite extends SparkFunSuite {
     assertContains("!!2!!", output2)
   }
+
+  test("SPARK-26633: ExecutorClassLoader.getResourceAsStream find REPL classes") {
+    val output = runInterpreterInPasteMode("local-cluster[1,1,1024]",
+      """
+        |case class TestClass(value: Int)
+        |
+        |sc.parallelize(1 to 1).map { _ =>
+        |  val clz = classOf[TestClass]
+        |  val name = clz.getName.replace('.', '/') + ".class";
+        |  val stream = clz.getClassLoader.getResourceAsStream(name)
+        |  if (stream == null) {
+        |    "failed: stream is null"
+        |  } else {
+        |    val magic = new Array[Byte](4)
+        |    try {
+        |      stream.read(magic)
+        |      // the magic number of a Java Class file
+        |      val expected = Array[Byte](0xCA.toByte, 0xFE.toByte, 0xBA.toByte, 0xBE.toByte)
+        |      if (magic sameElements expected) {
+        |        "successful"
+        |      } else {
+        |        "failed: unexpected contents from stream"
+        |      }
+        |    } finally {
+        |      stream.close()
+        |    }
+        |  }
+        |}.collect()
+      """.stripMargin)
+    assertDoesNotContain("failed", output)
+    assertContains("successful", output)
+  }
 
 }