Merge pull request #535 from ScrapCodes/scala-2.10-repl-port

Porting of the REPL to Scala 2.10
Matei Zaharia 2013-04-20 10:40:07 -07:00
commit 4b57f83209
16 changed files with 2095 additions and 1545 deletions


@@ -17,11 +17,11 @@ object SparkBuild extends Build {
//val HADOOP_VERSION = "2.0.0-mr1-cdh4.1.1"
//val HADOOP_MAJOR_VERSION = "2"
- lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, /*repl,*/ examples, bagel)
+ lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel)
lazy val core = Project("core", file("core"), settings = coreSettings)
- // lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core)
+ lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core)
lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core)
@@ -35,7 +35,7 @@ object SparkBuild extends Build {
organization := "org.spark-project",
version := "0.7.0-SNAPSHOT",
scalaVersion := "2.10.0",
- scalacOptions := Seq(/*"-deprecation",*/ "-unchecked", "-optimize"), // -deprecation is too noisy due to usage of old Hadoop API, enable it once that's no longer an issue
+ scalacOptions := Seq("-unchecked", "-optimize"),
unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath },
retrieveManaged := true,
retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
@@ -136,7 +136,9 @@ object SparkBuild extends Build {
"io.spray" %% "spray-json" % "1.2.3",
"colt" % "colt" % "1.2.0",
"org.apache.mesos" % "mesos" % "0.9.0-incubating",
- "org.scala-lang" % "scala-actors" % "2.10.0"
+ "org.scala-lang" % "scala-actors" % "2.10.0",
+ "org.scala-lang" % "jline" % "2.10.0",
+ "org.scala-lang" % "scala-reflect" % "2.10.0"
) ++ (if (HADOOP_MAJOR_VERSION == "2")
Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq,
unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") }
@@ -146,10 +148,11 @@ object SparkBuild extends Build {
publish := {}
)
- /* def replSettings = sharedSettings ++ Seq(
+ def replSettings = sharedSettings ++ Seq(
name := "spark-repl",
- libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-compiler" % _)
+ // libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-compiler" % _)
- )*/
+ libraryDependencies ++= Seq("org.scala-lang" % "scala-compiler" % "2.10.0")
+ )
def examplesSettings = sharedSettings ++ Seq(
name := "spark-examples"

Binary file not shown.


@@ -3,7 +3,7 @@ package spark.repl
import scala.collection.mutable.Set
object Main {
- private var _interp: SparkILoop = null
+ private var _interp: SparkILoop = _
def interp = _interp
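The only substantive change here is the initializer: in Scala, `var x: T = _` asks the compiler for the field's default value, which for a reference type is `null`, so the behaviour is unchanged and the 2.10-idiomatic form simply avoids spelling out `null`. A minimal illustration of the equivalence:

```scala
// Two equivalent ways to declare an uninitialized mutable field in Scala;
// for a reference type the default value produced by `_` is null.
class Holder {
  private var a: String = null // explicit
  private var b: String = _    // compiler-supplied default (null here)
  def same: Boolean = a == b   // true until either field is assigned
}
```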


@ -0,0 +1,109 @@
/* NSC -- new Scala compiler
* Copyright 2005-2013 LAMP/EPFL
* @author Paul Phillips
*/
package spark.repl
import scala.tools.nsc._
import scala.tools.nsc.interpreter._
import scala.reflect.internal.util.BatchSourceFile
import scala.tools.nsc.ast.parser.Tokens.EOF
import spark.Logging
trait SparkExprTyper extends Logging {
val repl: SparkIMain
import repl._
import global.{ reporter => _, Import => _, _ }
import definitions._
import syntaxAnalyzer.{ UnitParser, UnitScanner, token2name }
import naming.freshInternalVarName
object codeParser extends { val global: repl.global.type = repl.global } with CodeHandlers[Tree] {
def applyRule[T](code: String, rule: UnitParser => T): T = {
reporter.reset()
val scanner = newUnitParser(code)
val result = rule(scanner)
if (!reporter.hasErrors)
scanner.accept(EOF)
result
}
def defns(code: String) = stmts(code) collect { case x: DefTree => x }
def expr(code: String) = applyRule(code, _.expr())
def stmts(code: String) = applyRule(code, _.templateStats())
def stmt(code: String) = stmts(code).last // guaranteed nonempty
}
/** Parse a line into a sequence of trees. Returns None if the input is incomplete. */
def parse(line: String): Option[List[Tree]] = debugging(s"""parse("$line")""") {
var isIncomplete = false
reporter.withIncompleteHandler((_, _) => isIncomplete = true) {
val trees = codeParser.stmts(line)
if (reporter.hasErrors) Some(Nil)
else if (isIncomplete) None
else Some(trees)
}
}
// def parsesAsExpr(line: String) = {
// import codeParser._
// (opt expr line).isDefined
// }
def symbolOfLine(code: String): Symbol = {
def asExpr(): Symbol = {
val name = freshInternalVarName()
// Typing it with a lazy val would give us the right type, but runs
// into compiler bugs with things like existentials, so we compile it
// behind a def and strip the NullaryMethodType which wraps the expr.
val line = "def " + name + " = {\n" + code + "\n}"
interpretSynthetic(line) match {
case IR.Success =>
val sym0 = symbolOfTerm(name)
// drop NullaryMethodType
val sym = sym0.cloneSymbol setInfo afterTyper(sym0.info.finalResultType)
if (sym.info.typeSymbol eq UnitClass) NoSymbol
else sym
case _ => NoSymbol
}
}
def asDefn(): Symbol = {
val old = repl.definedSymbolList.toSet
interpretSynthetic(code) match {
case IR.Success =>
repl.definedSymbolList filterNot old match {
case Nil => NoSymbol
case sym :: Nil => sym
case syms => NoSymbol.newOverloaded(NoPrefix, syms)
}
case _ => NoSymbol
}
}
beQuietDuring(asExpr()) orElse beQuietDuring(asDefn())
}
private var typeOfExpressionDepth = 0
def typeOfExpression(expr: String, silent: Boolean = true): Type = {
if (typeOfExpressionDepth > 2) {
logDebug("Terminating typeOfExpression recursion for expression: " + expr)
return NoType
}
typeOfExpressionDepth += 1
// Don't presently have a good way to suppress undesirable success output
// while letting errors through, so it is first trying it silently: if there
// is an error, and errors are desired, then it re-evaluates non-silently
// to induce the error message.
try beSilentDuring(symbolOfLine(expr).tpe) match {
case NoType if !silent => symbolOfLine(expr).tpe // generate error
case tpe => tpe
}
finally typeOfExpressionDepth -= 1
}
}
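The trick in `symbolOfLine` is that an arbitrary expression is compiled behind a synthetic zero-argument `def`, and the `NullaryMethodType` wrapper is then stripped from the resulting symbol. For the user input `1 + 1`, the line handed to `interpretSynthetic` has roughly the following shape (the `$ires0` name is purely illustrative; the real name comes from `freshInternalVarName()`):

```scala
// Illustrative synthetic wrapper for the input "1 + 1".
def $ires0 = {
1 + 1
}
// On IR.Success the symbol of $ires0 is looked up, cloned, and its info replaced
// by finalResultType, so the caller sees a term of type Int rather than a
// NullaryMethodType; a Unit-typed result is translated to NoSymbol.
```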

File diff suppressed because it is too large.


@ -0,0 +1,143 @@
/* NSC -- new Scala compiler
* Copyright 2005-2013 LAMP/EPFL
* @author Paul Phillips
*/
package spark.repl
import scala.tools.nsc._
import scala.tools.nsc.interpreter._
import scala.reflect.internal.util.Position
import scala.util.control.Exception.ignoring
import scala.tools.nsc.util.stackTraceString
/**
* Machinery for the asynchronous initialization of the repl.
*/
trait SparkILoopInit {
self: SparkILoop =>
/** Print a welcome message */
def printWelcome() {
echo("""Welcome to
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
/___/ .__/\_,_/_/ /_/\_\ version 0.7.1-SNAPSHOT
/_/
""")
import Properties._
val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
versionString, javaVmName, javaVersion)
echo(welcomeMsg)
echo("Type in expressions to have them evaluated.")
echo("Type :help for more information.")
}
protected def asyncMessage(msg: String) {
if (isReplInfo || isReplPower)
echoAndRefresh(msg)
}
private val initLock = new java.util.concurrent.locks.ReentrantLock()
private val initCompilerCondition = initLock.newCondition() // signal the compiler is initialized
private val initLoopCondition = initLock.newCondition() // signal the whole repl is initialized
private val initStart = System.nanoTime
private def withLock[T](body: => T): T = {
initLock.lock()
try body
finally initLock.unlock()
}
// a condition used to ensure serial access to the compiler.
@volatile private var initIsComplete = false
@volatile private var initError: String = null
private def elapsed() = "%.3f".format((System.nanoTime - initStart).toDouble / 1000000000L)
// the method to be called when the interpreter is initialized.
// Very important this method does nothing synchronous (i.e. do
// not try to use the interpreter) because until it returns, the
// repl's lazy val `global` is still locked.
protected def initializedCallback() = withLock(initCompilerCondition.signal())
// Spins off a thread which awaits a single message once the interpreter
// has been initialized.
protected def createAsyncListener() = {
io.spawn {
withLock(initCompilerCondition.await())
asyncMessage("[info] compiler init time: " + elapsed() + " s.")
postInitialization()
}
}
// called from main repl loop
protected def awaitInitialized(): Boolean = {
if (!initIsComplete)
withLock { while (!initIsComplete) initLoopCondition.await() }
if (initError != null) {
println("""
|Failed to initialize the REPL due to an unexpected error.
|This is a bug, please, report it along with the error diagnostics printed below.
|%s.""".stripMargin.format(initError)
)
false
} else true
}
// private def warningsThunks = List(
// () => intp.bind("lastWarnings", "" + typeTag[List[(Position, String)]], intp.lastWarnings _),
// )
protected def postInitThunks = List[Option[() => Unit]](
Some(intp.setContextClassLoader _),
if (isReplPower) Some(() => enablePowerMode(true)) else None
).flatten
// ++ (
// warningsThunks
// )
// called once after init condition is signalled
protected def postInitialization() {
try {
postInitThunks foreach (f => addThunk(f()))
runThunks()
} catch {
case ex: Throwable =>
initError = stackTraceString(ex)
throw ex
} finally {
initIsComplete = true
if (isAsync) {
asyncMessage("[info] total init time: " + elapsed() + " s.")
withLock(initLoopCondition.signal())
}
}
}
def initializeSpark() {
intp.beQuietDuring {
command("""
@transient val sc = spark.repl.Main.interp.createSparkContext();
""")
command("import spark.SparkContext._");
}
echo("Spark context available as sc.")
}
// code to be executed only after the interpreter is initialized
// and the lazy val `global` can be accessed without risk of deadlock.
private var pendingThunks: List[() => Unit] = Nil
protected def addThunk(body: => Unit) = synchronized {
pendingThunks :+= (() => body)
}
protected def runThunks(): Unit = synchronized {
if (pendingThunks.nonEmpty)
logDebug("Clearing " + pendingThunks.size + " thunks.")
while (pendingThunks.nonEmpty) {
val thunk = pendingThunks.head
pendingThunks = pendingThunks.tail
thunk()
}
}
}
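`initializeSpark` quietly runs two interpreter commands at startup: one binds a `SparkContext` as `sc` via `spark.repl.Main.interp.createSparkContext()`, the other brings `spark.SparkContext._` into scope. That is what lets a fresh shell use `sc` with no setup on the user's part; for example, this session (mirroring one of the ReplSuite cases further down) works immediately after the welcome banner:

```scala
// Typed at the spark-shell prompt; sc is the context created by initializeSpark.
val accum = sc.accumulator(0)
sc.parallelize(1 to 10).foreach(x => accum += x)
accum.value // Int = 55
```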

File diff suppressed because it is too large.


@ -1,63 +0,0 @@
/* NSC -- new Scala compiler
* Copyright 2005-2011 LAMP/EPFL
* @author Alexander Spoon
*/
package spark.repl
import scala.tools.nsc._
import scala.tools.nsc.interpreter._
/** Settings for the interpreter
*
* @version 1.0
* @author Lex Spoon, 2007/3/24
**/
class SparkISettings(intp: SparkIMain) {
/** A list of paths where :load should look */
var loadPath = List(".")
/** Set this to true to see repl machinery under -Yrich-exceptions.
*/
var showInternalStackTraces = false
/** The maximum length of toString to use when printing the result
* of an evaluation. 0 means no maximum. If a printout requires
* more than this number of characters, then the printout is
* truncated.
*/
var maxPrintString = 800
/** The maximum number of completion candidates to print for tab
* completion without requiring confirmation.
*/
var maxAutoprintCompletion = 250
/** String unwrapping can be disabled if it is causing issues.
* Settings this to false means you will see Strings like "$iw.$iw.".
*/
var unwrapStrings = true
def deprecation_=(x: Boolean) = {
val old = intp.settings.deprecation.value
intp.settings.deprecation.value = x
if (!old && x) println("Enabled -deprecation output.")
else if (old && !x) println("Disabled -deprecation output.")
}
def deprecation: Boolean = intp.settings.deprecation.value
def allSettings = Map(
"maxPrintString" -> maxPrintString,
"maxAutoprintCompletion" -> maxAutoprintCompletion,
"unwrapStrings" -> unwrapStrings,
"deprecation" -> deprecation
)
private def allSettingsString =
allSettings.toList sortBy (_._1) map { case (k, v) => " " + k + " = " + v + "\n" } mkString
override def toString = """
| SparkISettings {
| %s
| }""".stripMargin.format(allSettingsString)
}


@ -1,9 +1,10 @@
/* NSC -- new Scala compiler /* NSC -- new Scala compiler
* Copyright 2005-2011 LAMP/EPFL * Copyright 2005-2013 LAMP/EPFL
* @author Paul Phillips * @author Paul Phillips
*/ */
package spark.repl package spark
package repl
import scala.tools.nsc._ import scala.tools.nsc._
import scala.tools.nsc.interpreter._ import scala.tools.nsc.interpreter._
@ -17,12 +18,15 @@ trait SparkImports {
import definitions.{ ScalaPackage, JavaLangPackage, PredefModule } import definitions.{ ScalaPackage, JavaLangPackage, PredefModule }
import memberHandlers._ import memberHandlers._
def isNoImports = settings.noimports.value
def isNoPredef = settings.nopredef.value
/** Synthetic import handlers for the language defined imports. */ /** Synthetic import handlers for the language defined imports. */
private def makeWildcardImportHandler(sym: Symbol): ImportHandler = { private def makeWildcardImportHandler(sym: Symbol): ImportHandler = {
val hd :: tl = sym.fullName.split('.').toList map newTermName val hd :: tl = sym.fullName.split('.').toList map newTermName
val tree = Import( val tree = Import(
tl.foldLeft(Ident(hd): Tree)((x, y) => Select(x, y)), tl.foldLeft(Ident(hd): Tree)((x, y) => Select(x, y)),
List(ImportSelector(nme.WILDCARD, -1, null, -1)) ImportSelector.wildList
) )
tree setSymbol sym tree setSymbol sym
new ImportHandler(tree) new ImportHandler(tree)
@ -33,8 +37,9 @@ trait SparkImports {
def languageWildcards: List[Type] = languageWildcardSyms map (_.tpe) def languageWildcards: List[Type] = languageWildcardSyms map (_.tpe)
def languageWildcardHandlers = languageWildcardSyms map makeWildcardImportHandler def languageWildcardHandlers = languageWildcardSyms map makeWildcardImportHandler
def importedTerms = onlyTerms(importHandlers flatMap (_.importedNames)) def allImportedNames = importHandlers flatMap (_.importedNames)
def importedTypes = onlyTypes(importHandlers flatMap (_.importedNames)) def importedTerms = onlyTerms(allImportedNames)
def importedTypes = onlyTypes(allImportedNames)
/** Types which have been wildcard imported, such as: /** Types which have been wildcard imported, such as:
* val x = "abc" ; import x._ // type java.lang.String * val x = "abc" ; import x._ // type java.lang.String
@ -48,10 +53,7 @@ trait SparkImports {
* into the compiler scopes. * into the compiler scopes.
*/ */
def sessionWildcards: List[Type] = { def sessionWildcards: List[Type] = {
importHandlers flatMap { importHandlers filter (_.importsWildcard) map (_.targetType) distinct
case x if x.importsWildcard => x.targetType
case _ => None
} distinct
} }
def wildcardTypes = languageWildcards ++ sessionWildcards def wildcardTypes = languageWildcards ++ sessionWildcards
@ -62,14 +64,15 @@ trait SparkImports {
def importedTypeSymbols = importedSymbols collect { case x: TypeSymbol => x } def importedTypeSymbols = importedSymbols collect { case x: TypeSymbol => x }
def implicitSymbols = importedSymbols filter (_.isImplicit) def implicitSymbols = importedSymbols filter (_.isImplicit)
def importedTermNamed(name: String) = importedTermSymbols find (_.name.toString == name) def importedTermNamed(name: String): Symbol =
importedTermSymbols find (_.name.toString == name) getOrElse NoSymbol
/** Tuples of (source, imported symbols) in the order they were imported. /** Tuples of (source, imported symbols) in the order they were imported.
*/ */
def importedSymbolsBySource: List[(Symbol, List[Symbol])] = { def importedSymbolsBySource: List[(Symbol, List[Symbol])] = {
val lang = languageWildcardSyms map (sym => (sym, membersAtPickler(sym))) val lang = languageWildcardSyms map (sym => (sym, membersAtPickler(sym)))
val session = importHandlers filter (_.targetType.isDefined) map { mh => val session = importHandlers filter (_.targetType != NoType) map { mh =>
(mh.targetType.get.typeSymbol, mh.importedSymbols) (mh.targetType.typeSymbol, mh.importedSymbols)
} }
lang ++ session lang ++ session
@ -90,7 +93,7 @@ trait SparkImports {
* 2. A code fragment that should go after the code * 2. A code fragment that should go after the code
* of the new request. * of the new request.
* *
* 3. An access path which can be traverested to access * 3. An access path which can be traversed to access
* any bindings inside code wrapped by #1 and #2 . * any bindings inside code wrapped by #1 and #2 .
* *
* The argument is a set of Names that need to be imported. * The argument is a set of Names that need to be imported.
@ -103,8 +106,9 @@ trait SparkImports {
* (3) It imports multiple same-named implicits, but only the * (3) It imports multiple same-named implicits, but only the
* last one imported is actually usable. * last one imported is actually usable.
*/ */
case class ComputedImports(prepend: String, append: String, access: String) case class SparkComputedImports(prepend: String, append: String, access: String)
protected def importsCode(wanted: Set[Name]): ComputedImports = {
protected def importsCode(wanted: Set[Name]): SparkComputedImports = {
/** Narrow down the list of requests from which imports /** Narrow down the list of requests from which imports
* should be taken. Removes requests which cannot contribute * should be taken. Removes requests which cannot contribute
* useful imports for the specified set of wanted names. * useful imports for the specified set of wanted names.
@ -116,12 +120,11 @@ trait SparkImports {
* 'wanted' is the set of names that need to be imported. * 'wanted' is the set of names that need to be imported.
*/ */
def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = { def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = {
val isWanted = wanted contains _
// Single symbol imports might be implicits! See bug #1752. Rather than // Single symbol imports might be implicits! See bug #1752. Rather than
// try to finesse this, we will mimic all imports for now. // try to finesse this, we will mimic all imports for now.
def keepHandler(handler: MemberHandler) = handler match { def keepHandler(handler: MemberHandler) = handler match {
case _: ImportHandler => true case _: ImportHandler => true
case x => x.definesImplicit || (x.definedNames exists isWanted) case x => x.definesImplicit || (x.definedNames exists wanted)
} }
reqs match { reqs match {
@ -149,6 +152,11 @@ trait SparkImports {
accessPath append ("." + impname) accessPath append ("." + impname)
currentImps.clear currentImps.clear
// code append "object %s {\n".format(impname)
// trailingBraces append "}\n"
// accessPath append ("." + impname)
// currentImps.clear
} }
addWrapper() addWrapper()
@ -159,7 +167,7 @@ trait SparkImports {
// If the user entered an import, then just use it; add an import wrapping // If the user entered an import, then just use it; add an import wrapping
// level if the import might conflict with some other import // level if the import might conflict with some other import
case x: ImportHandler => case x: ImportHandler =>
if (x.importsWildcard || (currentImps exists (x.importedNames contains _))) if (x.importsWildcard || currentImps.exists(x.importedNames contains _))
addWrapper() addWrapper()
code append (x.member + "\n") code append (x.member + "\n")
@ -175,20 +183,12 @@ trait SparkImports {
// handle quoting keywords separately. // handle quoting keywords separately.
case x => case x =>
for (imv <- x.definedNames) { for (imv <- x.definedNames) {
// MATEI: Changed this check because it was messing up for case classes if (currentImps contains imv) addWrapper()
// (trying to import them twice within the same wrapper), but that is more likely
// due to a miscomputation of names that makes the code think they're unique.
// Need to evaluate whether having so many wrappers is a bad thing.
/*if (currentImps contains imv)*/
val imvName = imv.toString
if (currentImps exists (_.toString == imvName)) addWrapper()
val objName = req.lineRep.readPath val objName = req.lineRep.readPath
val valName = "$VAL" + newValId(); val valName = "$VAL" + newValId();
code.append("val " + valName + " = " + objName + ".INSTANCE;\n") code.append("val " + valName + " = " + objName + ".INSTANCE;\n")
code.append("import " + valName + req.accessPath + ".`" + imv + "`;\n") code.append("import " + valName + req.accessPath + ".`" + imv + "`;\n")
// code append ("import " + (req fullPath imv) + "\n")
//code append ("import %s\n" format (req fullPath imv))
currentImps += imv currentImps += imv
} }
} }
@ -196,14 +196,14 @@ trait SparkImports {
// add one extra wrapper, to prevent warnings in the common case of // add one extra wrapper, to prevent warnings in the common case of
// redefining the value bound in the last interpreter request. // redefining the value bound in the last interpreter request.
addWrapper() addWrapper()
ComputedImports(code.toString, trailingBraces.toString, accessPath.toString) SparkComputedImports(code.toString, trailingBraces.toString, accessPath.toString)
} }
private def allReqAndHandlers = private def allReqAndHandlers =
prevRequestList flatMap (req => req.handlers map (req -> _)) prevRequestList flatMap (req => req.handlers map (req -> _))
private def membersAtPickler(sym: Symbol): List[Symbol] = private def membersAtPickler(sym: Symbol): List[Symbol] =
atPickler(sym.info.nonPrivateMembers) beforePickler(sym.info.nonPrivateMembers.toList)
private var curValId = 0 private var curValId = 0
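The code generated by `importsCode` is what stitches separate REPL lines together: for each name a new request needs, it prepends a `val` bound to the previous request's wrapper instance and an import of the name through that val. Roughly, for a prior line that defined `x`, the prepended fragment looks like the sketch below; the `$line1.$read` and `$iw` path segments are placeholders for whatever `req.lineRep.readPath` and `req.accessPath` actually contain:

```scala
// Hypothetical shape of the prepended wrapper code for an earlier definition of x.
val $VAL1 = $line1.$read.INSTANCE;
import $VAL1.$iw.`x`;
```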


@ -1,9 +1,11 @@
/* NSC -- new Scala compiler /* NSC -- new Scala compiler
* Copyright 2005-2011 LAMP/EPFL * Copyright 2005-2013 LAMP/EPFL
* @author Paul Phillips * @author Paul Phillips
*/ */
package spark.repl package spark
package repl
import scala.tools.nsc._ import scala.tools.nsc._
import scala.tools.nsc.interpreter._ import scala.tools.nsc.interpreter._
@ -11,29 +13,30 @@ import scala.tools.nsc.interpreter._
import scala.tools.jline._ import scala.tools.jline._
import scala.tools.jline.console.completer._ import scala.tools.jline.console.completer._
import Completion._ import Completion._
import collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import spark.Logging
// REPL completor - queries supplied interpreter for valid // REPL completor - queries supplied interpreter for valid
// completions based on current contents of buffer. // completions based on current contents of buffer.
class SparkJLineCompletion(val intp: SparkIMain) extends Completion with CompletionOutput { class SparkJLineCompletion(val intp: SparkIMain) extends Completion with CompletionOutput with Logging {
val global: intp.global.type = intp.global val global: intp.global.type = intp.global
import global._ import global._
import definitions.{ PredefModule, RootClass, AnyClass, AnyRefClass, ScalaPackage, JavaLangPackage } import definitions.{ PredefModule, AnyClass, AnyRefClass, ScalaPackage, JavaLangPackage }
import rootMirror.{ RootClass, getModuleIfDefined }
type ExecResult = Any type ExecResult = Any
import intp.{ DBG, debugging, afterTyper } import intp.{ debugging }
// verbosity goes up with consecutive tabs // verbosity goes up with consecutive tabs
private var verbosity: Int = 0 private var verbosity: Int = 0
def resetVerbosity() = verbosity = 0 def resetVerbosity() = verbosity = 0
def getType(name: String, isModule: Boolean) = { def getSymbol(name: String, isModule: Boolean) = (
val f = if (isModule) definitions.getModule(_: Name) else definitions.getClass(_: Name) if (isModule) getModuleIfDefined(name)
try Some(f(name).tpe) else getModuleIfDefined(name)
catch { case _: MissingRequirementError => None } )
} def getType(name: String, isModule: Boolean) = getSymbol(name, isModule).tpe
def typeOf(name: String) = getType(name, false)
def typeOf(name: String) = getType(name, false) def moduleOf(name: String) = getType(name, true)
def moduleOf(name: String) = getType(name, true)
trait CompilerCompletion { trait CompilerCompletion {
def tp: Type def tp: Type
@ -48,16 +51,16 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
private def anyMembers = AnyClass.tpe.nonPrivateMembers private def anyMembers = AnyClass.tpe.nonPrivateMembers
def anyRefMethodsToShow = Set("isInstanceOf", "asInstanceOf", "toString") def anyRefMethodsToShow = Set("isInstanceOf", "asInstanceOf", "toString")
def tos(sym: Symbol) = sym.name.decode.toString def tos(sym: Symbol): String = sym.decodedName
def memberNamed(s: String) = members find (x => tos(x) == s) def memberNamed(s: String) = afterTyper(effectiveTp member newTermName(s))
def hasMethod(s: String) = methods exists (x => tos(x) == s) def hasMethod(s: String) = memberNamed(s).isMethod
// XXX we'd like to say "filterNot (_.isDeprecated)" but this causes the // XXX we'd like to say "filterNot (_.isDeprecated)" but this causes the
// compiler to crash for reasons not yet known. // compiler to crash for reasons not yet known.
def members = afterTyper((effectiveTp.nonPrivateMembers ++ anyMembers) filter (_.isPublic)) def members = afterTyper((effectiveTp.nonPrivateMembers.toList ++ anyMembers) filter (_.isPublic))
def methods = members filter (_.isMethod) def methods = members.toList filter (_.isMethod)
def packages = members filter (_.isPackage) def packages = members.toList filter (_.isPackage)
def aliases = members filter (_.isAliasType) def aliases = members.toList filter (_.isAliasType)
def memberNames = members map tos def memberNames = members map tos
def methodNames = methods map tos def methodNames = methods map tos
@ -65,6 +68,13 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
def aliasNames = aliases map tos def aliasNames = aliases map tos
} }
object NoTypeCompletion extends TypeMemberCompletion(NoType) {
override def memberNamed(s: String) = NoSymbol
override def members = Nil
override def follow(s: String) = None
override def alternativesFor(id: String) = Nil
}
object TypeMemberCompletion { object TypeMemberCompletion {
def apply(tp: Type, runtimeType: Type, param: NamedParam): TypeMemberCompletion = { def apply(tp: Type, runtimeType: Type, param: NamedParam): TypeMemberCompletion = {
new TypeMemberCompletion(tp) { new TypeMemberCompletion(tp) {
@ -92,7 +102,8 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
} }
} }
def apply(tp: Type): TypeMemberCompletion = { def apply(tp: Type): TypeMemberCompletion = {
if (tp.typeSymbol.isPackageClass) new PackageCompletion(tp) if (tp eq NoType) NoTypeCompletion
else if (tp.typeSymbol.isPackageClass) new PackageCompletion(tp)
else new TypeMemberCompletion(tp) else new TypeMemberCompletion(tp)
} }
def imported(tp: Type) = new ImportCompletion(tp) def imported(tp: Type) = new ImportCompletion(tp)
@ -105,7 +116,7 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
def excludeNames: List[String] = (anyref.methodNames filterNot anyRefMethodsToShow) :+ "_root_" def excludeNames: List[String] = (anyref.methodNames filterNot anyRefMethodsToShow) :+ "_root_"
def methodSignatureString(sym: Symbol) = { def methodSignatureString(sym: Symbol) = {
SparkIMain stripString afterTyper(new MethodSymbolOutput(sym).methodString()) IMain stripString afterTyper(new MethodSymbolOutput(sym).methodString())
} }
def exclude(name: String): Boolean = ( def exclude(name: String): Boolean = (
@ -120,7 +131,7 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
debugging(tp + " completions ==> ")(filtered(memberNames)) debugging(tp + " completions ==> ")(filtered(memberNames))
override def follow(s: String): Option[CompletionAware] = override def follow(s: String): Option[CompletionAware] =
debugging(tp + " -> '" + s + "' ==> ")(memberNamed(s) map (x => TypeMemberCompletion(x.tpe))) debugging(tp + " -> '" + s + "' ==> ")(Some(TypeMemberCompletion(memberNamed(s).tpe)) filterNot (_ eq NoTypeCompletion))
override def alternativesFor(id: String): List[String] = override def alternativesFor(id: String): List[String] =
debugging(id + " alternatives ==> ") { debugging(id + " alternatives ==> ") {
@ -157,28 +168,29 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
object ids extends CompletionAware { object ids extends CompletionAware {
override def completions(verbosity: Int) = intp.unqualifiedIds ++ List("classOf") //, "_root_") override def completions(verbosity: Int) = intp.unqualifiedIds ++ List("classOf") //, "_root_")
// now we use the compiler for everything. // now we use the compiler for everything.
override def follow(id: String) = { override def follow(id: String): Option[CompletionAware] = {
if (completions(0) contains id) { if (!completions(0).contains(id))
intp typeOfExpression id map { tpe => return None
def default = TypeMemberCompletion(tpe)
// only rebinding vals in power mode for now. val tpe = intp typeOfExpression id
if (!isReplPower) default if (tpe == NoType)
else intp runtimeClassAndTypeOfTerm id match { return None
case Some((clazz, runtimeType)) =>
val sym = intp.symbolOfTerm(id) def default = Some(TypeMemberCompletion(tpe))
if (sym.isStable) {
val param = new NamedParam.Untyped(id, intp valueOfTerm id getOrElse null) // only rebinding vals in power mode for now.
TypeMemberCompletion(tpe, runtimeType, param) if (!isReplPower) default
} else intp runtimeClassAndTypeOfTerm id match {
else default case Some((clazz, runtimeType)) =>
case _ => val sym = intp.symbolOfTerm(id)
default if (sym.isStable) {
val param = new NamedParam.Untyped(id, intp valueOfTerm id getOrElse null)
Some(TypeMemberCompletion(tpe, runtimeType, param))
} }
} else default
case _ =>
default
} }
else
None
} }
override def toString = "<repl ids> (%s)".format(completions(0).size) override def toString = "<repl ids> (%s)".format(completions(0).size)
} }
@ -188,14 +200,7 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
// literal Ints, Strings, etc. // literal Ints, Strings, etc.
object literals extends CompletionAware { object literals extends CompletionAware {
def simpleParse(code: String): Tree = { def simpleParse(code: String): Tree = newUnitParser(code).templateStats().last
val unit = new CompilationUnit(new util.BatchSourceFile("<console>", code))
val scanner = new syntaxAnalyzer.UnitParser(unit)
val tss = scanner.templateStatSeq(false)._2
if (tss.size == 1) tss.head else EmptyTree
}
def completions(verbosity: Int) = Nil def completions(verbosity: Int) = Nil
override def follow(id: String) = simpleParse(id) match { override def follow(id: String) = simpleParse(id) match {
@ -280,19 +285,6 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
if (parsed.isEmpty) xs map ("." + _) else xs if (parsed.isEmpty) xs map ("." + _) else xs
} }
// chasing down results which won't parse
def execute(line: String): Option[ExecResult] = {
val parsed = Parsed(line)
def noDotOrSlash = line forall (ch => ch != '.' && ch != '/')
if (noDotOrSlash) None // we defer all unqualified ids to the repl.
else {
(ids executionFor parsed) orElse
(rootClass executionFor parsed) orElse
(FileCompletion executionFor line)
}
}
// generic interface for querying (e.g. interpreter loop, testing) // generic interface for querying (e.g. interpreter loop, testing)
def completions(buf: String): List[String] = def completions(buf: String): List[String] =
topLevelFor(Parsed.dotted(buf + ".", buf.length + 1)) topLevelFor(Parsed.dotted(buf + ".", buf.length + 1))
@ -327,7 +319,7 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
// This is jline's entry point for completion. // This is jline's entry point for completion.
override def complete(buf: String, cursor: Int): Candidates = { override def complete(buf: String, cursor: Int): Candidates = {
verbosity = if (isConsecutiveTabs(buf, cursor)) verbosity + 1 else 0 verbosity = if (isConsecutiveTabs(buf, cursor)) verbosity + 1 else 0
DBG("\ncomplete(%s, %d) last = (%s, %d), verbosity: %s".format(buf, cursor, lastBuf, lastCursor, verbosity)) logDebug("\ncomplete(%s, %d) last = (%s, %d), verbosity: %s".format(buf, cursor, lastBuf, lastCursor, verbosity))
// we don't try lower priority completions unless higher ones return no results. // we don't try lower priority completions unless higher ones return no results.
def tryCompletion(p: Parsed, completionFunction: Parsed => List[String]): Option[Candidates] = { def tryCompletion(p: Parsed, completionFunction: Parsed => List[String]): Option[Candidates] = {
@ -340,7 +332,7 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
val advance = commonPrefix(winners) val advance = commonPrefix(winners)
lastCursor = p.position + advance.length lastCursor = p.position + advance.length
lastBuf = (buf take p.position) + advance lastBuf = (buf take p.position) + advance
DBG("tryCompletion(%s, _) lastBuf = %s, lastCursor = %s, p.position = %s".format( logDebug("tryCompletion(%s, _) lastBuf = %s, lastCursor = %s, p.position = %s".format(
p, lastBuf, lastCursor, p.position)) p, lastBuf, lastCursor, p.position))
p.position p.position
} }
@ -356,23 +348,29 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
if (!looksLikeInvocation(buf)) None if (!looksLikeInvocation(buf)) None
else tryCompletion(Parsed.dotted(buf drop 1, cursor), lastResultFor) else tryCompletion(Parsed.dotted(buf drop 1, cursor), lastResultFor)
def regularCompletion = tryCompletion(mkDotted, topLevelFor) def tryAll = (
def fileCompletion = lastResultCompletion
if (!looksLikePath(buf)) None orElse tryCompletion(mkDotted, topLevelFor)
else tryCompletion(mkUndelimited, FileCompletion completionsFor _.buffer) getOrElse Candidates(cursor, Nil)
)
/** This is the kickoff point for all manner of theoretically possible compiler /**
* unhappiness - fault may be here or elsewhere, but we don't want to crash the * This is the kickoff point for all manner of theoretically
* repl regardless. Hopefully catching Exception is enough, but because the * possible compiler unhappiness. The fault may be here or
* compiler still throws some Errors it may not be. * elsewhere, but we don't want to crash the repl regardless.
* The compiler makes it impossible to avoid catching Throwable
* with its unfortunate tendency to throw java.lang.Errors and
* AssertionErrors as the hats drop. We take two swings at it
* because there are some spots which like to throw an assertion
* once, then work after that. Yeah, what can I say.
*/ */
try { try tryAll
(lastResultCompletion orElse regularCompletion orElse fileCompletion) getOrElse Candidates(cursor, Nil) catch { case ex: Throwable =>
} logWarning("Error: complete(%s, %s) provoked".format(buf, cursor) + ex)
catch { Candidates(cursor,
case ex: Exception => if (isReplDebug) List("<error:" + ex + ">")
DBG("Error: complete(%s, %s) provoked %s".format(buf, cursor, ex)) else Nil
Candidates(cursor, List(" ", "<completion error: " + ex.getMessage + ">")) )
} }
} }
} }


@ -1,5 +1,5 @@
/* NSC -- new Scala compiler /* NSC -- new Scala compiler
* Copyright 2005-2011 LAMP/EPFL * Copyright 2005-2013 LAMP/EPFL
* @author Stepan Koltsov * @author Stepan Koltsov
*/ */
@ -15,13 +15,15 @@ import scala.collection.JavaConverters._
import Completion._ import Completion._
import io.Streamable.slurp import io.Streamable.slurp
/** Reads from the console using JLine */ /**
class SparkJLineReader(val completion: Completion) extends InteractiveReader { * Reads from the console using JLine.
*/
class SparkJLineReader(_completion: => Completion) extends InteractiveReader {
val interactive = true val interactive = true
val consoleReader = new JLineConsoleReader()
lazy val completion = _completion
lazy val history: JLineHistory = JLineHistory() lazy val history: JLineHistory = JLineHistory()
lazy val keyBindings =
try KeyBinding parse slurp(term.getDefaultBindings)
catch { case _: Exception => Nil }
private def term = consoleReader.getTerminal() private def term = consoleReader.getTerminal()
def reset() = term.reset() def reset() = term.reset()
@ -37,6 +39,9 @@ class SparkJLineReader(val completion: Completion) extends InteractiveReader {
} }
class JLineConsoleReader extends ConsoleReader with ConsoleReaderHelper { class JLineConsoleReader extends ConsoleReader with ConsoleReaderHelper {
if ((history: History) ne NoHistory)
this setHistory history
// working around protected/trait/java insufficiencies. // working around protected/trait/java insufficiencies.
def goBack(num: Int): Unit = back(num) def goBack(num: Int): Unit = back(num)
def readOneKey(prompt: String) = { def readOneKey(prompt: String) = {
@ -46,34 +51,28 @@ class SparkJLineReader(val completion: Completion) extends InteractiveReader {
} }
def eraseLine() = consoleReader.resetPromptLine("", "", 0) def eraseLine() = consoleReader.resetPromptLine("", "", 0)
def redrawLineAndFlush(): Unit = { flush() ; drawLine() ; flush() } def redrawLineAndFlush(): Unit = { flush() ; drawLine() ; flush() }
// override def readLine(prompt: String): String
this setBellEnabled false // A hook for running code after the repl is done initializing.
if (history ne NoHistory) lazy val postInit: Unit = {
this setHistory history this setBellEnabled false
if (completion ne NoCompletion) { if (completion ne NoCompletion) {
val argCompletor: ArgumentCompleter = val argCompletor: ArgumentCompleter =
new ArgumentCompleter(new JLineDelimiter, scalaToJline(completion.completer())) new ArgumentCompleter(new JLineDelimiter, scalaToJline(completion.completer()))
argCompletor setStrict false argCompletor setStrict false
this addCompleter argCompletor this addCompleter argCompletor
this setAutoprintThreshold 400 // max completion candidates without warning this setAutoprintThreshold 400 // max completion candidates without warning
}
} }
} }
val consoleReader: JLineConsoleReader = new JLineConsoleReader() def currentLine = consoleReader.getCursorBuffer.buffer.toString
def currentLine: String = consoleReader.getCursorBuffer.buffer.toString
def redrawLine() = consoleReader.redrawLineAndFlush() def redrawLine() = consoleReader.redrawLineAndFlush()
def eraseLine() = { def eraseLine() = consoleReader.eraseLine()
while (consoleReader.delete()) { } // Alternate implementation, not sure if/when I need this.
// consoleReader.eraseLine() // def eraseLine() = while (consoleReader.delete()) { }
}
def readOneLine(prompt: String) = consoleReader readLine prompt def readOneLine(prompt: String) = consoleReader readLine prompt
def readOneKey(prompt: String) = consoleReader readOneKey prompt def readOneKey(prompt: String) = consoleReader readOneKey prompt
} }
object SparkJLineReader {
def apply(intp: SparkIMain): SparkJLineReader = apply(new SparkJLineCompletion(intp))
def apply(comp: Completion): SparkJLineReader = new SparkJLineReader(comp)
}
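The constructor change is the notable part of this file: the completion is now passed by name (`_completion: => Completion`) and cached in a `lazy val`, so the compiler-backed completer is only built when it is first needed rather than at reader construction time, presumably to fit the asynchronous initialization added in SparkILoopInit. The deferral pattern in isolation:

```scala
// By-name parameter plus lazy val: the argument expression runs at most once,
// and only when the field is first accessed.
class Reader(_completion: => String) {
  lazy val completion = _completion
}

object Demo extends App {
  val r = new Reader({ println("building completion"); "ready" })
  // Nothing has been printed yet; the block only runs on the first access below.
  println(r.completion) // prints "building completion", then "ready"
}
```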


@ -1,22 +1,24 @@
/* NSC -- new Scala compiler /* NSC -- new Scala compiler
* Copyright 2005-2011 LAMP/EPFL * Copyright 2005-2013 LAMP/EPFL
* @author Martin Odersky * @author Martin Odersky
*/ */
package spark.repl package spark
package repl
import scala.tools.nsc._ import scala.tools.nsc._
import scala.tools.nsc.interpreter._ import scala.tools.nsc.interpreter._
import scala.collection.{ mutable, immutable } import scala.collection.{ mutable, immutable }
import scala.PartialFunction.cond import scala.PartialFunction.cond
import scala.reflect.NameTransformer import scala.reflect.internal.Chars
import util.Chars import scala.reflect.internal.Flags._
import scala.language.implicitConversions
trait SparkMemberHandlers { trait SparkMemberHandlers {
val intp: SparkIMain val intp: SparkIMain
import intp.{ Request, global, naming, atPickler } import intp.{ Request, global, naming }
import global._ import global._
import naming._ import naming._
@ -54,26 +56,28 @@ trait SparkMemberHandlers {
} }
def chooseHandler(member: Tree): MemberHandler = member match { def chooseHandler(member: Tree): MemberHandler = member match {
case member: DefDef => new DefHandler(member) case member: DefDef => new DefHandler(member)
case member: ValDef => new ValHandler(member) case member: ValDef => new ValHandler(member)
case member@Assign(Ident(_), _) => new AssignHandler(member) case member: Assign => new AssignHandler(member)
case member: ModuleDef => new ModuleHandler(member) case member: ModuleDef => new ModuleHandler(member)
case member: ClassDef => new ClassHandler(member) case member: ClassDef => new ClassHandler(member)
case member: TypeDef => new TypeAliasHandler(member) case member: TypeDef => new TypeAliasHandler(member)
case member: Import => new ImportHandler(member) case member: Import => new ImportHandler(member)
case DocDef(_, documented) => chooseHandler(documented) case DocDef(_, documented) => chooseHandler(documented)
case member => new GenericHandler(member) case member => new GenericHandler(member)
} }
sealed abstract class MemberDefHandler(override val member: MemberDef) extends MemberHandler(member) { sealed abstract class MemberDefHandler(override val member: MemberDef) extends MemberHandler(member) {
def symbol = if (member.symbol eq null) NoSymbol else member.symbol
def name: Name = member.name def name: Name = member.name
def mods: Modifiers = member.mods def mods: Modifiers = member.mods
def keyword = member.keyword def keyword = member.keyword
def prettyName = NameTransformer.decode(name) def prettyName = name.decode
override def definesImplicit = member.mods.isImplicit override def definesImplicit = member.mods.isImplicit
override def definesTerm: Option[TermName] = Some(name.toTermName) filter (_ => name.isTermName) override def definesTerm: Option[TermName] = Some(name.toTermName) filter (_ => name.isTermName)
override def definesType: Option[TypeName] = Some(name.toTypeName) filter (_ => name.isTypeName) override def definesType: Option[TypeName] = Some(name.toTypeName) filter (_ => name.isTypeName)
override def definedSymbols = if (symbol eq NoSymbol) Nil else List(symbol)
} }
/** Class to handle one member among all the members included /** Class to handle one member among all the members included
@ -82,10 +86,7 @@ trait SparkMemberHandlers {
sealed abstract class MemberHandler(val member: Tree) { sealed abstract class MemberHandler(val member: Tree) {
def definesImplicit = false def definesImplicit = false
def definesValue = false def definesValue = false
def isLegalTopLevel = member match { def isLegalTopLevel = false
case _: ModuleDef | _: ClassDef | _: Import => true
case _ => false
}
def definesTerm = Option.empty[TermName] def definesTerm = Option.empty[TermName]
def definesType = Option.empty[TypeName] def definesType = Option.empty[TypeName]
@ -94,6 +95,7 @@ trait SparkMemberHandlers {
def importedNames = List[Name]() def importedNames = List[Name]()
def definedNames = definesTerm.toList ++ definesType.toList def definedNames = definesTerm.toList ++ definesType.toList
def definedOrImported = definedNames ++ importedNames def definedOrImported = definedNames ++ importedNames
def definedSymbols = List[Symbol]()
def extraCodeToEvaluate(req: Request): String = "" def extraCodeToEvaluate(req: Request): String = ""
def resultExtractionCode(req: Request): String = "" def resultExtractionCode(req: Request): String = ""
@ -117,21 +119,26 @@ trait SparkMemberHandlers {
if (mods.isLazy) codegenln(false, "<lazy>") if (mods.isLazy) codegenln(false, "<lazy>")
else any2stringOf(req fullPath name, maxStringElements) else any2stringOf(req fullPath name, maxStringElements)
""" + "%s: %s = " + %s""".format(prettyName, string2code(req typeOf name), resultString) val vidString =
if (replProps.vids) """" + " @ " + "%%8x".format(System.identityHashCode(%s)) + " """.trim.format(req fullPath name)
else ""
""" + "%s%s: %s = " + %s""".format(string2code(prettyName), vidString, string2code(req typeOf name), resultString)
} }
} }
} }
class DefHandler(member: DefDef) extends MemberDefHandler(member) { class DefHandler(member: DefDef) extends MemberDefHandler(member) {
private def vparamss = member.vparamss private def vparamss = member.vparamss
// true if 0-arity private def isMacro = member.symbol hasFlag MACRO
override def definesValue = vparamss.isEmpty || vparamss.head.isEmpty // true if not a macro and 0-arity
override def definesValue = !isMacro && flattensToEmpty(vparamss)
override def resultExtractionCode(req: Request) = override def resultExtractionCode(req: Request) =
if (mods.isPublic) codegenln(name, ": ", req.typeOf(name)) else "" if (mods.isPublic) codegenln(name, ": ", req.typeOf(name)) else ""
} }
class AssignHandler(member: Assign) extends MemberHandler(member) { class AssignHandler(member: Assign) extends MemberHandler(member) {
val lhs = member.lhs.asInstanceOf[Ident] // an unfortunate limitation val Assign(lhs, rhs) = member
val name = newTermName(freshInternalVarName()) val name = newTermName(freshInternalVarName())
override def definesTerm = Some(name) override def definesTerm = Some(name)
@ -142,15 +149,15 @@ trait SparkMemberHandlers {
/** Print out lhs instead of the generated varName */ /** Print out lhs instead of the generated varName */
override def resultExtractionCode(req: Request) = { override def resultExtractionCode(req: Request) = {
val lhsType = string2code(req lookupTypeOf name) val lhsType = string2code(req lookupTypeOf name)
val res = string2code(req fullPath name) val res = string2code(req fullPath name)
""" + "%s: %s = " + %s + "\n" """.format(string2code(lhs.toString), lhsType, res) + "\n"
""" + "%s: %s = " + %s + "\n" """.format(lhs, lhsType, res) + "\n"
} }
} }
class ModuleHandler(module: ModuleDef) extends MemberDefHandler(module) { class ModuleHandler(module: ModuleDef) extends MemberDefHandler(module) {
override def definesTerm = Some(name) override def definesTerm = Some(name)
override def definesValue = true override def definesValue = true
override def isLegalTopLevel = true
override def resultExtractionCode(req: Request) = codegenln("defined module ", name) override def resultExtractionCode(req: Request) = codegenln("defined module ", name)
} }
@ -158,6 +165,7 @@ trait SparkMemberHandlers {
class ClassHandler(member: ClassDef) extends MemberDefHandler(member) { class ClassHandler(member: ClassDef) extends MemberDefHandler(member) {
override def definesType = Some(name.toTypeName) override def definesType = Some(name.toTypeName)
override def definesTerm = Some(name.toTermName) filter (_ => mods.isCase) override def definesTerm = Some(name.toTermName) filter (_ => mods.isCase)
override def isLegalTopLevel = true
override def resultExtractionCode(req: Request) = override def resultExtractionCode(req: Request) =
codegenln("defined %s %s".format(keyword, name)) codegenln("defined %s %s".format(keyword, name))
@ -173,7 +181,20 @@ trait SparkMemberHandlers {
class ImportHandler(imp: Import) extends MemberHandler(imp) { class ImportHandler(imp: Import) extends MemberHandler(imp) {
val Import(expr, selectors) = imp val Import(expr, selectors) = imp
def targetType = intp.typeOfExpression("" + expr) def targetType: Type = intp.typeOfExpression("" + expr)
override def isLegalTopLevel = true
def createImportForName(name: Name): String = {
selectors foreach {
case sel @ ImportSelector(old, _, `name`, _) => return "import %s.{ %s }".format(expr, sel)
case _ => ()
}
"import %s.%s".format(expr, name)
}
// TODO: Need to track these specially to honor Predef masking attempts,
// because they must be the leading imports in the code generated for each
// line. We can use the same machinery as Contexts now, anyway.
def isPredefImport = isReferenceToPredef(expr)
// wildcard imports, e.g. import foo._ // wildcard imports, e.g. import foo._
private def selectorWild = selectors filter (_.name == nme.USCOREkw) private def selectorWild = selectors filter (_.name == nme.USCOREkw)
@ -183,14 +204,17 @@ trait SparkMemberHandlers {
/** Whether this import includes a wildcard import */ /** Whether this import includes a wildcard import */
val importsWildcard = selectorWild.nonEmpty val importsWildcard = selectorWild.nonEmpty
/** Whether anything imported is implicit .*/
def importsImplicit = implicitSymbols.nonEmpty
def implicitSymbols = importedSymbols filter (_.isImplicit) def implicitSymbols = importedSymbols filter (_.isImplicit)
def importedSymbols = individualSymbols ++ wildcardSymbols def importedSymbols = individualSymbols ++ wildcardSymbols
lazy val individualSymbols: List[Symbol] = lazy val individualSymbols: List[Symbol] =
atPickler(targetType.toList flatMap (tp => individualNames map (tp nonPrivateMember _))) beforePickler(individualNames map (targetType nonPrivateMember _))
lazy val wildcardSymbols: List[Symbol] = lazy val wildcardSymbols: List[Symbol] =
if (importsWildcard) atPickler(targetType.toList flatMap (_.nonPrivateMembers)) if (importsWildcard) beforePickler(targetType.nonPrivateMembers.toList)
else Nil else Nil
/** Complete list of names imported by a wildcard */ /** Complete list of names imported by a wildcard */


@ -1,51 +1,14 @@
package spark.repl package spark.repl
import java.io._ import java.io.FileWriter
import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import org.scalatest.FunSuite import org.scalatest.FunSuite
import com.google.common.io.Files import com.google.common.io.Files
class ReplSuite extends FunSuite { class ReplSuite extends FunSuite with ReplSuiteMixin {
def runInterpreter(master: String, input: String): String = {
val in = new BufferedReader(new StringReader(input + "\n"))
val out = new StringWriter()
val cl = getClass.getClassLoader
var paths = new ArrayBuffer[String]
if (cl.isInstanceOf[URLClassLoader]) {
val urlLoader = cl.asInstanceOf[URLClassLoader]
for (url <- urlLoader.getURLs) {
if (url.getProtocol == "file") {
paths += url.getFile
}
}
}
val interp = new SparkILoop(in, new PrintWriter(out), master)
spark.repl.Main.interp = interp
val separator = System.getProperty("path.separator")
interp.process(Array("-classpath", paths.mkString(separator)))
spark.repl.Main.interp = null
if (interp.sparkContext != null)
interp.sparkContext.stop()
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.master.port")
return out.toString
}
def assertContains(message: String, output: String) { test("simple foreach with accumulator") {
assert(output contains message,
"Interpreter output did not contain '" + message + "':\n" + output)
}
def assertDoesNotContain(message: String, output: String) {
assert(!(output contains message),
"Interpreter output contained '" + message + "':\n" + output)
}
test ("simple foreach with accumulator") {
val output = runInterpreter("local", """ val output = runInterpreter("local", """
val accum = sc.accumulator(0) val accum = sc.accumulator(0)
sc.parallelize(1 to 10).foreach(x => accum += x) sc.parallelize(1 to 10).foreach(x => accum += x)
@ -56,7 +19,7 @@ class ReplSuite extends FunSuite {
assertContains("res1: Int = 55", output) assertContains("res1: Int = 55", output)
} }
test ("external vars") { test("external vars") {
val output = runInterpreter("local", """ val output = runInterpreter("local", """
var v = 7 var v = 7
sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
@ -69,7 +32,7 @@ class ReplSuite extends FunSuite {
assertContains("res1: Int = 100", output) assertContains("res1: Int = 100", output)
} }
test ("external classes") { test("external classes") {
val output = runInterpreter("local", """ val output = runInterpreter("local", """
class C { class C {
def foo = 5 def foo = 5
@ -81,7 +44,7 @@ class ReplSuite extends FunSuite {
assertContains("res0: Int = 50", output) assertContains("res0: Int = 50", output)
} }
test ("external functions") { test("external functions") {
val output = runInterpreter("local", """ val output = runInterpreter("local", """
def double(x: Int) = x + x def double(x: Int) = x + x
sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_)
@ -91,7 +54,7 @@ class ReplSuite extends FunSuite {
assertContains("res0: Int = 110", output) assertContains("res0: Int = 110", output)
} }
test ("external functions that access vars") { test("external functions that access vars") {
val output = runInterpreter("local", """ val output = runInterpreter("local", """
var v = 7 var v = 7
def getV() = v def getV() = v
@ -105,7 +68,7 @@ class ReplSuite extends FunSuite {
assertContains("res1: Int = 100", output) assertContains("res1: Int = 100", output)
} }
test ("broadcast vars") { test("broadcast vars") {
// Test that the value that a broadcast var had when it was created is used, // Test that the value that a broadcast var had when it was created is used,
// even if that variable is then modified in the driver program // even if that variable is then modified in the driver program
// TODO: This doesn't actually work for arrays when we run in local mode! // TODO: This doesn't actually work for arrays when we run in local mode!
@ -122,7 +85,7 @@ class ReplSuite extends FunSuite {
assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output) assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
} }
test ("interacting with files") { test("interacting with files") {
val tempDir = Files.createTempDir() val tempDir = Files.createTempDir()
val out = new FileWriter(tempDir + "/input") val out = new FileWriter(tempDir + "/input")
out.write("Hello world!\n") out.write("Hello world!\n")
@ -143,7 +106,7 @@ class ReplSuite extends FunSuite {
} }
if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { if (System.getenv("MESOS_NATIVE_LIBRARY") != null) {
test ("running on Mesos") { test("running on Mesos") {
val output = runInterpreter("localquiet", """ val output = runInterpreter("localquiet", """
var v = 7 var v = 7
def getV() = v def getV() = v
@ -164,4 +127,5 @@ class ReplSuite extends FunSuite {
assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
} }
} }
} }


@ -0,0 +1,60 @@
package spark.repl
import java.io.BufferedReader
import java.io.PrintWriter
import java.io.StringReader
import java.io.StringWriter
import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.future
import spark.deploy.master.Master
import spark.deploy.worker.Worker
trait ReplSuiteMixin {
val localIp = "127.0.1.2"
val port = "7089"
val sparkUrl = s"spark://$localIp:$port"
def setupStandaloneCluster() {
future { Master.main(Array("-i", localIp, "-p", port, "--webui-port", "0")) }
Thread.sleep(2000)
future { Worker.main(Array(sparkUrl, "--webui-port", "0")) }
}
def runInterpreter(master: String, input: String): String = {
val in = new BufferedReader(new StringReader(input + "\n"))
val out = new StringWriter()
val cl = getClass.getClassLoader
var paths = new ArrayBuffer[String]
if (cl.isInstanceOf[URLClassLoader]) {
val urlLoader = cl.asInstanceOf[URLClassLoader]
for (url <- urlLoader.getURLs) {
if (url.getProtocol == "file") {
paths += url.getFile
}
}
}
val interp = new SparkILoop(in, new PrintWriter(out), master)
spark.repl.Main.interp = interp
val separator = System.getProperty("path.separator")
interp.process(Array("-classpath", paths.mkString(separator)))
if (interp != null)
interp.closeInterpreter();
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.master.port")
return out.toString
}
def assertContains(message: String, output: String) {
assert(output contains message,
"Interpreter output did not contain '" + message + "':\n" + output)
}
def assertDoesNotContain(message: String, output: String) {
assert(!(output contains message),
"Interpreter output contained '" + message + "':\n" + output)
}
}


@ -0,0 +1,103 @@
package spark.repl
import java.io.FileWriter
import org.scalatest.FunSuite
import com.google.common.io.Files
class StandaloneClusterReplSuite extends FunSuite with ReplSuiteMixin {
setupStandaloneCluster
test("simple collect") {
val output = runInterpreter(sparkUrl, """
var x = 123
val data = sc.parallelize(1 to 3).map(_ + x)
data.take(3)
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("124", output)
assertContains("125", output)
assertContains("126", output)
}
test("simple foreach with accumulator") {
val output = runInterpreter(sparkUrl, """
val accum = sc.accumulator(0)
sc.parallelize(1 to 10).foreach(x => accum += x)
accum.value
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res1: Int = 55", output)
}
test("external vars") {
val output = runInterpreter(sparkUrl, """
var v = 7
sc.parallelize(1 to 10).map(x => v).take(10).reduceLeft(_+_)
v = 10
sc.parallelize(1 to 10).map(x => v).take(10).reduceLeft(_+_)
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
}
test("external classes") {
val output = runInterpreter(sparkUrl, """
class C {
def foo = 5
}
sc.parallelize(1 to 10).map(x => (new C).foo).take(10).reduceLeft(_+_)
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 50", output)
}
test("external functions") {
val output = runInterpreter(sparkUrl, """
def double(x: Int) = x + x
sc.parallelize(1 to 10).map(x => double(x)).take(10).reduceLeft(_+_)
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 110", output)
}
test("external functions that access vars") {
val output = runInterpreter(sparkUrl, """
var v = 7
def getV() = v
sc.parallelize(1 to 10).map(x => getV()).take(10).reduceLeft(_+_)
v = 10
sc.parallelize(1 to 10).map(x => getV()).take(10).reduceLeft(_+_)
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
}
test("broadcast vars") {
// Test that the value that a broadcast var had when it was created is used,
// even if that variable is then modified in the driver program
val output = runInterpreter(sparkUrl, """
var array = new Array[Int](5)
val broadcastArray = sc.broadcast(array)
sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).take(5)
array(0) = 5
sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).take(5)
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output)
assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
}
}

run

@@ -1,6 +1,6 @@
#!/bin/bash
- SCALA_VERSION=2.9.2
+ SCALA_VERSION=2.10
# Figure out where the Scala framework is installed
FWDIR="$(cd `dirname $0`; pwd)"