[SPARK-30090][SHELL] Adapt Spark REPL to Scala 2.13

### What changes were proposed in this pull request?

This is an attempt to adapt the Spark REPL to Scala 2.13.

It is based on a [scala-2.13 branch](https://github.com/smarter/spark/tree/scala-2.13) made by smarter.

I had to set the Scala version to 2.13 in some places and adapt some other modules before I could start working on the REPL itself. These are separate commits on the branch that would probably be addressed beforehand, and thus dropped before this PR is merged.

I couldn't find a way to run the initialization code with the existing REPL classes in Scala 2.13.2, so I [modified the Scala REPL](e9cc0dd547) to make it work. With this modification I managed to run the Spark shell, and the unit tests pass, which is good news.

The bad news is that it requires an upstream change in Scala, which would have to be accepted first. I'd be happy to change the approach if someone points out a way to do it differently. If not, I'll propose a PR in Scala to introduce `ILoop.internalReplAutorunCode`.
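
For reference, here is a minimal sketch of what such a hook would let a subclass do, mirroring the `internalReplAutorunCode` override in the new `SparkILoop` further down in this diff. It is illustrative only: `AutorunILoop` is a made-up name, and the override assumes the proposed member exists on Scala's `ILoop`.

```scala
import java.io.{BufferedReader, PrintWriter}
import scala.tools.nsc.GenericRunnerSettings
import scala.tools.nsc.interpreter.shell.{ILoop, ShellConfig}

// Hypothetical subclass: the REPL would run these commands quietly at startup,
// before the first prompt is shown.
class AutorunILoop(in: BufferedReader, out: PrintWriter)
  extends ILoop(ShellConfig(new GenericRunnerSettings(_ => ())), in, out) {

  // Proposed hook (not in stock Scala 2.13.2): code the REPL evaluates on startup.
  override protected def internalReplAutorunCode(): Seq[String] =
    Seq("""println("startup code ran")""")
}
```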

### Why are the changes needed?

The REPL implementation in Scala 2.13 changed quite a lot, so the current version of the Spark REPL needs to be adapted.

### Does this PR introduce _any_ user-facing change?

In the previous version of `SparkILoop`, a lot of Scala's `ILoop` code was [overridden and duplicated](2bc7b75537) to make the welcome message a bit more pleasant. In this PR, the message appears in a slightly different order, but it's still acceptable IMHO.

Before this PR:
```
20/05/15 15:32:39 WARN Utils: Your hostname, hermes resolves to a loopback address: 127.0.1.1; using 192.168.1.28 instead (on interface enp0s31f6)
20/05/15 15:32:39 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
20/05/15 15:32:39 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
20/05/15 15:32:45 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
Spark context Web UI available at http://192.168.1.28:4041
Spark context available as 'sc' (master = local[*], app id = local-1589549565502).
Spark session available as 'spark'.
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 3.0.1-SNAPSHOT
      /_/

Using Scala version 2.12.10 (OpenJDK 64-Bit Server VM, Java 1.8.0_242)
Type in expressions to have them evaluated.
Type :help for more information.

scala>
```

With this PR:
```
20/05/15 15:32:15 WARN Utils: Your hostname, hermes resolves to a loopback address: 127.0.1.1; using 192.168.1.28 instead (on interface enp0s31f6)
20/05/15 15:32:15 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
20/05/15 15:32:15 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 3.0.0-SNAPSHOT
      /_/

Using Scala version 2.13.2-20200422-211118-706ef1b (OpenJDK 64-Bit Server VM, Java 1.8.0_242)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context Web UI available at http://192.168.1.28:4040
Spark context available as 'sc' (master = local[*], app id = local-1589549541259).
Spark session available as 'spark'.

scala>
```

The welcome message is still an improvement over [the original ticket](https://issues.apache.org/jira/browse/SPARK-24785), albeit in a different order: the banner now appears before the Spark context lines, apparently because the 2.13 REPL prints its welcome before running the startup code. As a bonus, some fragile code duplication was removed.

### How was this patch tested?

Existing tests pass in the `repl` module. The REPL was run in a terminal and the following code executed correctly (a sketch of driving the shell programmatically follows the snippet):

```
scala> spark.range(1000 * 1000 * 1000).count()
val res0: Long = 1000000000
```
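
As an aside, the `SparkILoop` companion object in this PR provides a `run(code, settings)` helper (see the file below) that feeds a snippet to a fresh interpreter and returns its captured output. Below is a minimal, hypothetical smoke-test sketch built on it, assuming Spark and the `repl` module are on the classpath; `ReplSmokeTest` and the chosen master are my own assumptions, not part of this PR.

```scala
import scala.tools.nsc.Settings
import org.apache.spark.repl.{Main, SparkILoop}

object ReplSmokeTest {
  def main(args: Array[String]): Unit = {
    // The REPL's startup code creates a SparkSession, so give it a master to use.
    Main.conf.set("spark.master", "local[2]")

    val settings = new Settings
    settings.usejavacp.value = true // interpret against the JVM's classpath

    // Feed a snippet to a fresh SparkILoop and capture everything it prints.
    val output = SparkILoop.run("spark.range(10).count()", settings)
    assert(output.contains("res0: Long = 10"), s"unexpected REPL output:\n$output")
  }
}
```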

Closes #28545 from karolchmist/scala-2.13-repl.

Authored-by: Karol Chmist <info+github@chmist.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
Commit 3be552ccc8 (parent bbbd907780), 2020-09-12 18:15:15 -05:00
11 changed files with 741 additions and 89 deletions

@@ -0,0 +1,138 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.repl

import java.io.File
import java.net.URI
import java.util.Locale

import scala.tools.nsc.GenericRunnerSettings

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.util.Utils

object Main extends Logging {

  initializeLogIfNecessary(true)
  Signaling.cancelOnInterrupt()

  val conf = new SparkConf()
  val rootDir =
    conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf))
  val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl")

  var sparkContext: SparkContext = _
  var sparkSession: SparkSession = _
  // this is a public var because tests reset it.
  var interp: SparkILoop = _

  private var hasErrors = false
  private var isShellSession = false

  private def scalaOptionError(msg: String): Unit = {
    hasErrors = true
    // scalastyle:off println
    Console.err.println(msg)
    // scalastyle:on println
  }

  def main(args: Array[String]): Unit = {
    isShellSession = true
    doMain(args, new SparkILoop)
  }

  // Visible for testing
  private[repl] def doMain(args: Array[String], _interp: SparkILoop): Unit = {
    interp = _interp
    val jars = Utils
      .getLocalUserJarsForShell(conf)
      // Remove file:///, file:// or file:/ scheme if exists for each jar
      .map { x =>
        if (x.startsWith("file:")) new File(new URI(x)).getPath else x
      }
      .mkString(File.pathSeparator)
    val interpArguments = List(
      "-Yrepl-class-based",
      "-Yrepl-outdir",
      s"${outputDir.getAbsolutePath}",
      "-classpath",
      jars
    ) ++ args.toList

    val settings = new GenericRunnerSettings(scalaOptionError)
    settings.processArguments(interpArguments, true)

    if (!hasErrors) {
      interp.run(settings) // Repl starts and goes in loop of R.E.P.L
      Option(sparkContext).foreach(_.stop)
    }
  }

  def createSparkSession(): SparkSession = {
    try {
      val execUri = System.getenv("SPARK_EXECUTOR_URI")
      conf.setIfMissing("spark.app.name", "Spark shell")
      // SparkContext will detect this configuration and register it with the RpcEnv's
      // file server, setting spark.repl.class.uri to the actual URI for executors to
      // use. This is sort of ugly but since executors are started as part of SparkContext
      // initialization in certain cases, there's an initialization order issue that prevents
      // this from being set after SparkContext is instantiated.
      conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath())
      if (execUri != null) {
        conf.set("spark.executor.uri", execUri)
      }
      if (System.getenv("SPARK_HOME") != null) {
        conf.setSparkHome(System.getenv("SPARK_HOME"))
      }

      val builder = SparkSession.builder.config(conf)
      if (conf
          .get(CATALOG_IMPLEMENTATION.key, "hive")
          .toLowerCase(Locale.ROOT) == "hive") {
        if (SparkSession.hiveClassesArePresent) {
          // In the case that the property is not set at all, builder's config
          // does not have this value set to 'hive' yet. The original default
          // behavior is that when there are hive classes, we use hive catalog.
          sparkSession = builder.enableHiveSupport().getOrCreate()
          logInfo("Created Spark session with Hive support")
        } else {
          // Need to change it back to 'in-memory' if no hive classes are found
          // in the case that the property is set to hive in spark-defaults.conf
          builder.config(CATALOG_IMPLEMENTATION.key, "in-memory")
          sparkSession = builder.getOrCreate()
          logInfo("Created Spark session")
        }
      } else {
        // In the case that the property is set but not to 'hive', the internal
        // default is 'in-memory'. So the sparkSession will use in-memory catalog.
        sparkSession = builder.getOrCreate()
        logInfo("Created Spark session")
      }
      sparkContext = sparkSession.sparkContext
      sparkSession
    } catch {
      case e: Exception if isShellSession =>
        logError("Failed to initialize Spark session.", e)
        sys.exit(1)
    }
  }
}

@@ -0,0 +1,149 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.repl

import java.io.{BufferedReader, PrintWriter}

// scalastyle:off println
import scala.Predef.{println => _, _}
import scala.tools.nsc.GenericRunnerSettings
import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter.shell.{ILoop, ShellConfig}
import scala.tools.nsc.util.stringFromStream
import scala.util.Properties.{javaVersion, javaVmName, versionString}
// scalastyle:on println

/**
 * A Spark-specific interactive shell.
 */
class SparkILoop(in0: BufferedReader, out: PrintWriter)
    extends ILoop(ShellConfig(new GenericRunnerSettings(_ => ())), in0, out) {

  def this() = this(null, new PrintWriter(Console.out, true))

  val initializationCommands: Seq[String] = Seq(
    """
    @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) {
        org.apache.spark.repl.Main.sparkSession
      } else {
        org.apache.spark.repl.Main.createSparkSession()
      }
    @transient val sc = {
      val _sc = spark.sparkContext
      if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) {
        val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null)
        if (proxyUrl != null) {
          println(
            s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}")
        } else {
          println(s"Spark Context Web UI is available at Spark Master Public URL")
        }
      } else {
        _sc.uiWebUrl.foreach {
          webUrl => println(s"Spark context Web UI available at ${webUrl}")
        }
      }
      println("Spark context available as 'sc' " +
        s"(master = ${_sc.master}, app id = ${_sc.applicationId}).")
      println("Spark session available as 'spark'.")
      _sc
    }
    """,
    "import org.apache.spark.SparkContext._",
    "import spark.implicits._",
    "import spark.sql",
    "import org.apache.spark.sql.functions._"
  )

  override protected def internalReplAutorunCode(): Seq[String] =
    initializationCommands

  def initializeSpark(): Unit = {
    if (!intp.reporter.hasErrors) {
      // `savingReplayStack` removes the commands from session history.
      savingReplayStack {
        initializationCommands.foreach(intp quietRun _)
      }
    } else {
      throw new RuntimeException(
        s"Scala $versionString interpreter encountered " +
          "errors during initialization"
      )
    }
  }

  /** Print a welcome message */
  override def printWelcome(): Unit = {
    import org.apache.spark.SPARK_VERSION
    echo("""Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version %s
      /_/
         """.format(SPARK_VERSION))
    val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
      versionString,
      javaVmName,
      javaVersion
    )
    echo(welcomeMsg)
    echo("Type in expressions to have them evaluated.")
    echo("Type :help for more information.")
  }

  /** Available commands */
  override def commands: List[LoopCommand] = standardCommands

  override def resetCommand(line: String): Unit = {
    super.resetCommand(line)
    initializeSpark()
    echo(
      "Note that after :reset, state of SparkSession and SparkContext is unchanged."
    )
  }

  override def replay(): Unit = {
    initializeSpark()
    super.replay()
  }
}

object SparkILoop {

  /**
   * Creates an interpreter loop with default settings and feeds
   * the given code to it as input.
   */
  def run(code: String, sets: Settings = new Settings): String = {
    import java.io.{BufferedReader, StringReader, OutputStreamWriter}

    stringFromStream { ostream =>
      Console.withOut(ostream) {
        val input = new BufferedReader(new StringReader(code))
        val output = new PrintWriter(new OutputStreamWriter(ostream), true)
        val repl = new SparkILoop(input, output)

        if (sets.classpath.isDefault) {
          sets.classpath.value = sys.props("java.class.path")
        }
        repl.run(sets)
      }
    }
  }

  def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString)
}

@@ -0,0 +1,58 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.repl

import java.io._
import java.nio.file.Files

import scala.tools.nsc.interpreter.SimpleReader

import org.apache.log4j.{Level, LogManager, PropertyConfigurator}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION

class Repl2Suite extends SparkFunSuite with BeforeAndAfterAll {
  test("propagation of local properties") {
    // A mock ILoop that doesn't install the SIGINT handler.
    class ILoop(out: PrintWriter) extends SparkILoop(None, out) {
      settings = new scala.tools.nsc.Settings
      settings.usejavacp.value = true
      org.apache.spark.repl.Main.interp = this
      in = SimpleReader()
    }

    val out = new StringWriter()
    Main.interp = new ILoop(new PrintWriter(out))
    Main.sparkContext = new SparkContext("local", "repl-test")
    Main.interp.createInterpreter()

    Main.sparkContext.setLocalProperty("someKey", "someValue")

    // Make sure the value we set in the caller to interpret is propagated in the thread that
    // interprets the command.
    Main.interp.interpret("org.apache.spark.repl.Main.sparkContext.getLocalProperty(\"someKey\")")
    assert(out.toString.contains("someValue"))

    Main.sparkContext.stop()
    System.clearProperty("spark.driver.port")
  }
}

@@ -0,0 +1,171 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.repl

import java.io._

import org.apache.spark.SparkFunSuite

/**
 * A special test suite for REPL that all test cases share one REPL instance.
 */
class SingletonRepl2Suite extends SparkFunSuite {

  private val out = new StringWriter()
  private val in = new PipedOutputStream()
  private var thread: Thread = _

  private val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
  private val oldExecutorClasspath = System.getProperty(CONF_EXECUTOR_CLASSPATH)

  override def beforeAll(): Unit = {
    super.beforeAll()

    val classpath = System.getProperty("java.class.path")
    System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath)

    Main.conf.set("spark.master", "local-cluster[2,1,1024]")
    val interp = new SparkILoop(
      new BufferedReader(new InputStreamReader(new PipedInputStream(in))),
      new PrintWriter(out))

    // Forces to create new SparkContext
    Main.sparkContext = null
    Main.sparkSession = null

    // Starts a new thread to run the REPL interpreter, so that we won't block.
    thread = new Thread(() => Main.doMain(Array("-classpath", classpath), interp))
    thread.setDaemon(true)
    thread.start()

    waitUntil(() => out.toString.contains("Type :help for more information"))
  }

  override def afterAll(): Unit = {
    in.close()
    thread.join()
    if (oldExecutorClasspath != null) {
      System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath)
    } else {
      System.clearProperty(CONF_EXECUTOR_CLASSPATH)
    }
    super.afterAll()
  }

  private def waitUntil(cond: () => Boolean): Unit = {
    import scala.concurrent.duration._
    import org.scalatest.concurrent.Eventually._

    eventually(timeout(50.seconds), interval(500.millis)) {
      assert(cond(), "current output: " + out.toString)
    }
  }

  /**
   * Run the given commands string in a globally shared interpreter instance. Note that the given
   * commands should not crash the interpreter, to not affect other test cases.
   */
  def runInterpreter(input: String): String = {
    val currentOffset = out.getBuffer.length()
    // append a special statement to the end of the given code, so that we can know what's
    // the final output of this code snippet and rely on it to wait until the output is ready.
    val timestamp = System.currentTimeMillis()
    in.write((input + s"\nval _result_$timestamp = 1\n").getBytes)
    in.flush()
    val stopMessage = s"_result_$timestamp: Int = 1"
    waitUntil(() => out.getBuffer.substring(currentOffset).contains(stopMessage))
    out.getBuffer.substring(currentOffset)
  }

  def assertContains(message: String, output: String): Unit = {
    val isContain = output.contains(message)
    assert(isContain,
      "Interpreter output did not contain '" + message + "':\n" + output)
  }

  def assertDoesNotContain(message: String, output: String): Unit = {
    val isContain = output.contains(message)
    assert(!isContain,
      "Interpreter output contained '" + message + "':\n" + output)
  }

  test("SPARK-31399: should clone+clean line object w/ non-serializable state in ClosureCleaner") {
    // Test ClosureCleaner when a closure captures the enclosing `this` REPL line object, and that
    // object contains an unused non-serializable field.
    // Specifically, the closure in this test case contains a directly nested closure, and the
    // capture is triggered by the inner closure.
    // `ns` should be nulled out, but `topLevelValue` should stay intact.

    // Can't use :paste mode because PipedOutputStream/PipedInputStream doesn't work well with the
    // EOT control character (i.e. Ctrl+D).
    // Just write things on a single line to emulate :paste mode.

    // NOTE: in order for this test case to trigger the intended scenario, the following three
    // variables need to be in the same "input", which will make the REPL pack them into the
    // same REPL line object:
    //   - ns: a non-serializable state, not accessed by the closure;
    //   - topLevelValue: a serializable state, accessed by the closure;
    //   - closure: the starting closure, captures the enclosing REPL line object.
    val output = runInterpreter(
      """
        |class NotSerializableClass(val x: Int)
        |val ns = new NotSerializableClass(42); val topLevelValue = "someValue"; val closure =
        |(j: Int) => {
        |  (1 to j).flatMap { x =>
        |    (1 to x).map { y => y + topLevelValue }
        |  }
        |}
        |val r = sc.parallelize(0 to 2).map(closure).collect
      """.stripMargin)
    assertContains("r: Array[scala.collection.immutable.IndexedSeq[String]] = " +
      "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output)
    // assertContains("r: Array[IndexedSeq[String]] = " +
    //   "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output)
    assertDoesNotContain("Exception", output)
  }

  test("SPARK-31399: ClosureCleaner should discover indirectly nested closure in inner class") {
    // Similar to the previous test case, but with indirect closure nesting instead.
    // There's still nested closures involved, but the inner closure is indirectly nested in the
    // outer closure, with a level of inner class in between them.
    // This changes how the inner closure references/captures the outer closure/enclosing `this`
    // REPL line object, and covers a different code path in inner closure discovery.
    // `ns` should be nulled out, but `topLevelValue` should stay intact.
    val output = runInterpreter(
      """
        |class NotSerializableClass(val x: Int)
        |val ns = new NotSerializableClass(42); val topLevelValue = "someValue"; val closure =
        |(j: Int) => {
        |  class InnerFoo {
        |    val innerClosure = (x: Int) => (1 to x).map { y => y + topLevelValue }
        |  }
        |  val innerFoo = new InnerFoo
        |  (1 to j).flatMap(innerFoo.innerClosure)
        |}
        |val r = sc.parallelize(0 to 2).map(closure).collect
      """.stripMargin)
    assertContains("r: Array[scala.collection.immutable.IndexedSeq[String]] = " +
      "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output)
    // assertContains("r: Array[IndexedSeq[String]] = " +
    //   "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output)
    assertDoesNotContain("Array(Vector(), Vector(1null), Vector(1null, 1null, 2null)", output)
    assertDoesNotContain("Exception", output)
  }
}

@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.repl

import java.io._
import java.nio.file.Files

import org.apache.log4j.{Level, LogManager, PropertyConfigurator}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION

class Repl2Suite extends SparkFunSuite with BeforeAndAfterAll {
  test("propagation of local properties") {
    // A mock ILoop that doesn't install the SIGINT handler.
    class ILoop(out: PrintWriter) extends SparkILoop(null, out)

    val out = new StringWriter()
    Main.interp = new ILoop(new PrintWriter(out))
    Main.sparkContext = new SparkContext("local", "repl-test")

    val settings = new scala.tools.nsc.Settings
    settings.usejavacp.value = true
    Main.interp.createInterpreter(settings)

    Main.sparkContext.setLocalProperty("someKey", "someValue")

    // Make sure the value we set in the caller to interpret is propagated in the thread that
    // interprets the command.
    Main.interp.interpret("org.apache.spark.repl.Main.sparkContext.getLocalProperty(\"someKey\")")
    assert(out.toString.contains("someValue"))

    Main.sparkContext.stop()
    System.clearProperty("spark.driver.port")
  }
}

@@ -0,0 +1,171 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.repl

import java.io._

import org.apache.spark.SparkFunSuite

/**
 * A special test suite for REPL that all test cases share one REPL instance.
 */
class SingletonRepl2Suite extends SparkFunSuite {

  private val out = new StringWriter()
  private val in = new PipedOutputStream()
  private var thread: Thread = _

  private val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
  private val oldExecutorClasspath = System.getProperty(CONF_EXECUTOR_CLASSPATH)

  override def beforeAll(): Unit = {
    super.beforeAll()

    val classpath = System.getProperty("java.class.path")
    System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath)

    Main.conf.set("spark.master", "local-cluster[2,1,1024]")
    val interp = new SparkILoop(
      new BufferedReader(new InputStreamReader(new PipedInputStream(in))),
      new PrintWriter(out))

    // Forces to create new SparkContext
    Main.sparkContext = null
    Main.sparkSession = null

    // Starts a new thread to run the REPL interpreter, so that we won't block.
    thread = new Thread(() => Main.doMain(Array("-classpath", classpath), interp))
    thread.setDaemon(true)
    thread.start()

    waitUntil(() => out.toString.contains("Type :help for more information"))
  }

  override def afterAll(): Unit = {
    in.close()
    thread.join()
    if (oldExecutorClasspath != null) {
      System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath)
    } else {
      System.clearProperty(CONF_EXECUTOR_CLASSPATH)
    }
    super.afterAll()
  }

  private def waitUntil(cond: () => Boolean): Unit = {
    import scala.concurrent.duration._
    import org.scalatest.concurrent.Eventually._

    eventually(timeout(50.seconds), interval(500.millis)) {
      assert(cond(), "current output: " + out.toString)
    }
  }

  /**
   * Run the given commands string in a globally shared interpreter instance. Note that the given
   * commands should not crash the interpreter, to not affect other test cases.
   */
  def runInterpreter(input: String): String = {
    val currentOffset = out.getBuffer.length()
    // append a special statement to the end of the given code, so that we can know what's
    // the final output of this code snippet and rely on it to wait until the output is ready.
    val timestamp = System.currentTimeMillis()
    in.write((input + s"\nval _result_$timestamp = 1\n").getBytes)
    in.flush()
    val stopMessage = s"_result_$timestamp: Int = 1"
    waitUntil(() => out.getBuffer.substring(currentOffset).contains(stopMessage))
    out.getBuffer.substring(currentOffset)
  }

  def assertContains(message: String, output: String): Unit = {
    val isContain = output.contains(message)
    assert(isContain,
      "Interpreter output did not contain '" + message + "':\n" + output)
  }

  def assertDoesNotContain(message: String, output: String): Unit = {
    val isContain = output.contains(message)
    assert(!isContain,
      "Interpreter output contained '" + message + "':\n" + output)
  }

  test("SPARK-31399: should clone+clean line object w/ non-serializable state in ClosureCleaner") {
    // Test ClosureCleaner when a closure captures the enclosing `this` REPL line object, and that
    // object contains an unused non-serializable field.
    // Specifically, the closure in this test case contains a directly nested closure, and the
    // capture is triggered by the inner closure.
    // `ns` should be nulled out, but `topLevelValue` should stay intact.

    // Can't use :paste mode because PipedOutputStream/PipedInputStream doesn't work well with the
    // EOT control character (i.e. Ctrl+D).
    // Just write things on a single line to emulate :paste mode.

    // NOTE: in order for this test case to trigger the intended scenario, the following three
    // variables need to be in the same "input", which will make the REPL pack them into the
    // same REPL line object:
    //   - ns: a non-serializable state, not accessed by the closure;
    //   - topLevelValue: a serializable state, accessed by the closure;
    //   - closure: the starting closure, captures the enclosing REPL line object.
    val output = runInterpreter(
      """
        |class NotSerializableClass(val x: Int)
        |val ns = new NotSerializableClass(42); val topLevelValue = "someValue"; val closure =
        |(j: Int) => {
        |  (1 to j).flatMap { x =>
        |    (1 to x).map { y => y + topLevelValue }
        |  }
        |}
        |val r = sc.parallelize(0 to 2).map(closure).collect
      """.stripMargin)
    // assertContains("r: Array[scala.collection.immutable.IndexedSeq[String]] = " +
    //   "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output)
    assertContains("r: Array[IndexedSeq[String]] = " +
      "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output)
    assertDoesNotContain("Exception", output)
  }

  test("SPARK-31399: ClosureCleaner should discover indirectly nested closure in inner class") {
    // Similar to the previous test case, but with indirect closure nesting instead.
    // There's still nested closures involved, but the inner closure is indirectly nested in the
    // outer closure, with a level of inner class in between them.
    // This changes how the inner closure references/captures the outer closure/enclosing `this`
    // REPL line object, and covers a different code path in inner closure discovery.
    // `ns` should be nulled out, but `topLevelValue` should stay intact.
    val output = runInterpreter(
      """
        |class NotSerializableClass(val x: Int)
        |val ns = new NotSerializableClass(42); val topLevelValue = "someValue"; val closure =
        |(j: Int) => {
        |  class InnerFoo {
        |    val innerClosure = (x: Int) => (1 to x).map { y => y + topLevelValue }
        |  }
        |  val innerFoo = new InnerFoo
        |  (1 to j).flatMap(innerFoo.innerClosure)
        |}
        |val r = sc.parallelize(0 to 2).map(closure).collect
      """.stripMargin)
    // assertContains("r: Array[scala.collection.immutable.IndexedSeq[String]] = " +
    //   "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output)
    assertContains("r: Array[IndexedSeq[String]] = " +
      "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output)
    assertDoesNotContain("Array(Vector(), Vector(1null), Vector(1null, 1null, 2null)", output)
    assertDoesNotContain("Exception", output)
  }
}

@@ -20,8 +20,6 @@ package org.apache.spark.repl
import java.io._
import java.nio.file.Files
import scala.tools.nsc.interpreter.SimpleReader
import org.apache.log4j.{Level, LogManager, PropertyConfigurator}
import org.scalatest.BeforeAndAfterAll
@@ -86,31 +84,6 @@ class ReplSuite extends SparkFunSuite with BeforeAndAfterAll {
"Interpreter output contained '" + message + "':\n" + output)
}
test("propagation of local properties") {
// A mock ILoop that doesn't install the SIGINT handler.
class ILoop(out: PrintWriter) extends SparkILoop(None, out) {
settings = new scala.tools.nsc.Settings
settings.usejavacp.value = true
org.apache.spark.repl.Main.interp = this
in = SimpleReader()
}
val out = new StringWriter()
Main.interp = new ILoop(new PrintWriter(out))
Main.sparkContext = new SparkContext("local", "repl-test")
Main.interp.createInterpreter()
Main.sparkContext.setLocalProperty("someKey", "someValue")
// Make sure the value we set in the caller to interpret is propagated in the thread that
// interprets the command.
Main.interp.interpret("org.apache.spark.repl.Main.sparkContext.getLocalProperty(\"someKey\")")
assert(out.toString.contains("someValue"))
Main.sparkContext.stop()
System.clearProperty("spark.driver.port")
}
test("SPARK-15236: use Hive catalog") {
// turn on the INFO log so that it is possible the code will dump INFO
// entry for using "HiveMetastore"

@@ -380,67 +380,6 @@ class SingletonReplSuite extends SparkFunSuite {
assertDoesNotContain("Exception", output)
}
test("SPARK-31399: should clone+clean line object w/ non-serializable state in ClosureCleaner") {
// Test ClosureCleaner when a closure captures the enclosing `this` REPL line object, and that
// object contains an unused non-serializable field.
// Specifically, the closure in this test case contains a directly nested closure, and the
// capture is triggered by the inner closure.
// `ns` should be nulled out, but `topLevelValue` should stay intact.
// Can't use :paste mode because PipedOutputStream/PipedInputStream doesn't work well with the
// EOT control character (i.e. Ctrl+D).
// Just write things on a single line to emulate :paste mode.
// NOTE: in order for this test case to trigger the intended scenario, the following three
// variables need to be in the same "input", which will make the REPL pack them into the
// same REPL line object:
// - ns: a non-serializable state, not accessed by the closure;
// - topLevelValue: a serializable state, accessed by the closure;
// - closure: the starting closure, captures the enclosing REPL line object.
val output = runInterpreter(
"""
|class NotSerializableClass(val x: Int)
|val ns = new NotSerializableClass(42); val topLevelValue = "someValue"; val closure =
|(j: Int) => {
| (1 to j).flatMap { x =>
| (1 to x).map { y => y + topLevelValue }
| }
|}
|val r = sc.parallelize(0 to 2).map(closure).collect
""".stripMargin)
assertContains("r: Array[scala.collection.immutable.IndexedSeq[String]] = " +
"Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output)
assertDoesNotContain("Exception", output)
}
test("SPARK-31399: ClosureCleaner should discover indirectly nested closure in inner class") {
// Similar to the previous test case, but with indirect closure nesting instead.
// There's still nested closures involved, but the inner closure is indirectly nested in the
// outer closure, with a level of inner class in between them.
// This changes how the inner closure references/captures the outer closure/enclosing `this`
// REPL line object, and covers a different code path in inner closure discovery.
// `ns` should be nulled out, but `topLevelValue` should stay intact.
val output = runInterpreter(
"""
|class NotSerializableClass(val x: Int)
|val ns = new NotSerializableClass(42); val topLevelValue = "someValue"; val closure =
|(j: Int) => {
| class InnerFoo {
| val innerClosure = (x: Int) => (1 to x).map { y => y + topLevelValue }
| }
| val innerFoo = new InnerFoo
| (1 to j).flatMap(innerFoo.innerClosure)
|}
|val r = sc.parallelize(0 to 2).map(closure).collect
""".stripMargin)
assertContains("r: Array[scala.collection.immutable.IndexedSeq[String]] = " +
"Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output)
assertDoesNotContain("Array(Vector(), Vector(1null), Vector(1null, 1null, 2null)", output)
assertDoesNotContain("Exception", output)
}
test("newProductSeqEncoder with REPL defined class") {
val output = runInterpreter(
"""

@@ -43,7 +43,7 @@ class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Ma
new CaseInsensitiveMap[B1](originalMap.filter(!_._1.equalsIgnoreCase(key)) + (key -> value))
}
override def +[B1 >: T](kv: (String, B1)): CaseInsensitiveMap[B1] = this.updated(kv._1, kv._2)
override def +[B1 >: T](kv: (String, B1)): CaseInsensitiveMap[B1] = this.updated(kv._1, kv._2)
def ++(xs: IterableOnce[(String, T)]): CaseInsensitiveMap[T] = {
xs.iterator.foldLeft(this) { (m, kv) => m.updated(kv._1, kv._2) }