[SPARK-22393][SPARK-SHELL] spark-shell can't find imported types in class constructors, extends clause
## What changes were proposed in this pull request? [SPARK-22393](https://issues.apache.org/jira/browse/SPARK-22393) ## How was this patch tested? With a new test case in `RepSuite` ---- This code is a retrofit of the Scala [SI-9881](https://github.com/scala/bug/issues/9881) bug fix, which never made it into the Scala 2.11 branches. Pushing these changes directly to the Scala repo is not practical (see: https://github.com/scala/scala/pull/6195). Author: Mark Petruska <petruska.mark@gmail.com> Closes #19846 from mpetruska/SPARK-22393.
This commit is contained in:
parent
16adaf634b
commit
9d06a9e0cf
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.repl
|
||||
|
||||
import scala.tools.nsc.interpreter.{ExprTyper, IR}
|
||||
|
||||
trait SparkExprTyper extends ExprTyper {
|
||||
|
||||
import repl._
|
||||
import global.{reporter => _, Import => _, _}
|
||||
import naming.freshInternalVarName
|
||||
|
||||
def doInterpret(code: String): IR.Result = {
|
||||
// interpret/interpretSynthetic may change the phase,
|
||||
// which would have unintended effects on types.
|
||||
val savedPhase = phase
|
||||
try interpretSynthetic(code) finally phase = savedPhase
|
||||
}
|
||||
|
||||
override def symbolOfLine(code: String): Symbol = {
|
||||
def asExpr(): Symbol = {
|
||||
val name = freshInternalVarName()
|
||||
// Typing it with a lazy val would give us the right type, but runs
|
||||
// into compiler bugs with things like existentials, so we compile it
|
||||
// behind a def and strip the NullaryMethodType which wraps the expr.
|
||||
val line = "def " + name + " = " + code
|
||||
|
||||
doInterpret(line) match {
|
||||
case IR.Success =>
|
||||
val sym0 = symbolOfTerm(name)
|
||||
// drop NullaryMethodType
|
||||
sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType)
|
||||
case _ => NoSymbol
|
||||
}
|
||||
}
|
||||
|
||||
def asDefn(): Symbol = {
|
||||
val old = repl.definedSymbolList.toSet
|
||||
|
||||
doInterpret(code) match {
|
||||
case IR.Success =>
|
||||
repl.definedSymbolList filterNot old match {
|
||||
case Nil => NoSymbol
|
||||
case sym :: Nil => sym
|
||||
case syms => NoSymbol.newOverloaded(NoPrefix, syms)
|
||||
}
|
||||
case _ => NoSymbol
|
||||
}
|
||||
}
|
||||
|
||||
def asError(): Symbol = {
|
||||
doInterpret(code)
|
||||
NoSymbol
|
||||
}
|
||||
|
||||
beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError()
|
||||
}
|
||||
|
||||
}
|
|
@ -35,6 +35,10 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
|
|||
def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
|
||||
def this() = this(None, new JPrintWriter(Console.out, true))
|
||||
|
||||
override def createInterpreter(): Unit = {
|
||||
intp = new SparkILoopInterpreter(settings, out)
|
||||
}
|
||||
|
||||
val initializationCommands: Seq[String] = Seq(
|
||||
"""
|
||||
@transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) {
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.repl
|
||||
|
||||
import scala.tools.nsc.Settings
|
||||
import scala.tools.nsc.interpreter._
|
||||
|
||||
class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) {
|
||||
self =>
|
||||
|
||||
override lazy val memberHandlers = new {
|
||||
val intp: self.type = self
|
||||
} with MemberHandlers {
|
||||
import intp.global._
|
||||
|
||||
override def chooseHandler(member: intp.global.Tree): MemberHandler = member match {
|
||||
case member: Import => new SparkImportHandler(member)
|
||||
case _ => super.chooseHandler (member)
|
||||
}
|
||||
|
||||
class SparkImportHandler(imp: Import) extends ImportHandler(imp: Import) {
|
||||
|
||||
override def targetType: Type = intp.global.rootMirror.getModuleIfDefined("" + expr) match {
|
||||
case NoSymbol => intp.typeOfExpression("" + expr)
|
||||
case sym => sym.tpe
|
||||
}
|
||||
|
||||
private def safeIndexOf(name: Name, s: String): Int = fixIndexOf(name, pos(name, s))
|
||||
private def fixIndexOf(name: Name, idx: Int): Int = if (idx == name.length) -1 else idx
|
||||
private def pos(name: Name, s: String): Int = {
|
||||
var i = name.pos(s.charAt(0), 0)
|
||||
val sLen = s.length()
|
||||
if (sLen == 1) return i
|
||||
while (i + sLen <= name.length) {
|
||||
var j = 1
|
||||
while (s.charAt(j) == name.charAt(i + j)) {
|
||||
j += 1
|
||||
if (j == sLen) return i
|
||||
}
|
||||
i = name.pos(s.charAt(0), i + 1)
|
||||
}
|
||||
name.length
|
||||
}
|
||||
|
||||
private def isFlattenedSymbol(sym: Symbol): Boolean =
|
||||
sym.owner.isPackageClass &&
|
||||
sym.name.containsName(nme.NAME_JOIN_STRING) &&
|
||||
sym.owner.info.member(sym.name.take(
|
||||
safeIndexOf(sym.name, nme.NAME_JOIN_STRING))) != NoSymbol
|
||||
|
||||
private def importableTargetMembers =
|
||||
importableMembers(exitingTyper(targetType)).filterNot(isFlattenedSymbol).toList
|
||||
|
||||
def isIndividualImport(s: ImportSelector): Boolean =
|
||||
s.name != nme.WILDCARD && s.rename != nme.WILDCARD
|
||||
def isWildcardImport(s: ImportSelector): Boolean =
|
||||
s.name == nme.WILDCARD
|
||||
|
||||
// non-wildcard imports
|
||||
private def individualSelectors = selectors filter isIndividualImport
|
||||
|
||||
override val importsWildcard: Boolean = selectors exists isWildcardImport
|
||||
|
||||
lazy val importableSymbolsWithRenames: List[(Symbol, Name)] = {
|
||||
val selectorRenameMap =
|
||||
individualSelectors.flatMap(x => x.name.bothNames zip x.rename.bothNames).toMap
|
||||
importableTargetMembers flatMap (m => selectorRenameMap.get(m.name) map (m -> _))
|
||||
}
|
||||
|
||||
override lazy val individualSymbols: List[Symbol] = importableSymbolsWithRenames map (_._1)
|
||||
override lazy val wildcardSymbols: List[Symbol] =
|
||||
if (importsWildcard) importableTargetMembers else Nil
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object expressionTyper extends {
|
||||
val repl: SparkILoopInterpreter.this.type = self
|
||||
} with SparkExprTyper { }
|
||||
|
||||
override def symbolOfLine(code: String): global.Symbol =
|
||||
expressionTyper.symbolOfLine(code)
|
||||
|
||||
override def typeOfExpression(expr: String, silent: Boolean): global.Type =
|
||||
expressionTyper.typeOfExpression(expr, silent)
|
||||
|
||||
}
|
|
@ -227,4 +227,14 @@ class ReplSuite extends SparkFunSuite {
|
|||
assertDoesNotContain("error: not found: value sc", output)
|
||||
}
|
||||
|
||||
test("spark-shell should find imported types in class constructors and extends clause") {
|
||||
val output = runInterpreter("local",
|
||||
"""
|
||||
|import org.apache.spark.Partition
|
||||
|class P(p: Partition)
|
||||
|class P(val index: Int) extends Partition
|
||||
""".stripMargin)
|
||||
assertDoesNotContain("error: not found: type Partition", output)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue