[SPARK-22393][SPARK-SHELL] spark-shell can't find imported types in class constructors, extends clause

## What changes were proposed in this pull request?

[SPARK-22393](https://issues.apache.org/jira/browse/SPARK-22393)

## How was this patch tested?

With a new test case in `ReplSuite`.

----

This change retrofits the fix for Scala bug [SI-9881](https://github.com/scala/bug/issues/9881), which never made it into the Scala 2.11 branch. Pushing these changes upstream to the Scala repo is not practical (see https://github.com/scala/scala/pull/6195).
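
For illustration, the failure mode (reconstructed from the new `ReplSuite` test below; exact console wording may vary by version) looks like this in `spark-shell`:

    scala> import org.apache.spark.Partition
    import org.apache.spark.Partition

    scala> class P(p: Partition)
    <console>:12: error: not found: type Partition
           class P(p: Partition)
                      ^

The bug only manifests when the REPL wraps each line in a class rather than an object (`-Yrepl-class-based`), which the Spark shell enables so that captured state serializes cleanly to executors.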

Author: Mark Petruska <petruska.mark@gmail.com>

Closes #19846 from mpetruska/SPARK-22393.
Commit 9d06a9e0cf (parent 16adaf634b), authored 2017-12-01 05:14:12 -06:00 by Mark Petruska, committed by Sean Owen.
4 changed files with 191 additions and 0 deletions.

SparkExprTyper.scala (new file)
@@ -0,0 +1,74 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.repl

import scala.tools.nsc.interpreter.{ExprTyper, IR}

trait SparkExprTyper extends ExprTyper {

  import repl._
  import global.{reporter => _, Import => _, _}
  import naming.freshInternalVarName

  def doInterpret(code: String): IR.Result = {
    // interpret/interpretSynthetic may change the phase,
    // which would have unintended effects on types.
    val savedPhase = phase
    try interpretSynthetic(code) finally phase = savedPhase
  }

  override def symbolOfLine(code: String): Symbol = {
    def asExpr(): Symbol = {
      val name = freshInternalVarName()
      // Typing it with a lazy val would give us the right type, but runs
      // into compiler bugs with things like existentials, so we compile it
      // behind a def and strip the NullaryMethodType which wraps the expr.
      val line = "def " + name + " = " + code
      doInterpret(line) match {
        case IR.Success =>
          val sym0 = symbolOfTerm(name)
          // drop NullaryMethodType
          sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType)
        case _ => NoSymbol
      }
    }

    def asDefn(): Symbol = {
      val old = repl.definedSymbolList.toSet

      doInterpret(code) match {
        case IR.Success =>
          repl.definedSymbolList filterNot old match {
            case Nil => NoSymbol
            case sym :: Nil => sym
            case syms => NoSymbol.newOverloaded(NoPrefix, syms)
          }
        case _ => NoSymbol
      }
    }

    def asError(): Symbol = {
      doInterpret(code)
      NoSymbol
    }

    beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError()
  }

}
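
A sketch of what `asExpr` above effectively does (plain Scala 2.11 with scala-reflect; `Holder` and `wrapped` are illustrative stand-ins, not part of the patch): a `def` wraps its result type in a NullaryMethodType, and `finalResultType` recovers the underlying expression type.

    import scala.reflect.runtime.universe._

    // Stands in for the synthetic wrapper "def $ires0 = <code>".
    class Holder { def wrapped = 1 + 1 }

    object NullaryDemo extends App {
      val tpe = typeOf[Holder].member(TermName("wrapped")).info
      println(tpe)                 // "=> Int" -- a NullaryMethodType
      println(tpe.finalResultType) // "Int"   -- the expression's real type
    }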

SparkILoop.scala
@@ -35,6 +35,10 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
  def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
  def this() = this(None, new JPrintWriter(Console.out, true))

  override def createInterpreter(): Unit = {
    intp = new SparkILoopInterpreter(settings, out)
  }

  val initializationCommands: Seq[String] = Seq(
    """
    @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) {
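
For context (a hedged paraphrase, not part of this diff): stock Scala 2.11 `ILoop.createInterpreter` assigns `intp = new ILoopInterpreter`, so overriding that hook is the supported way to substitute an `IMain` subclass. A minimal sketch of standing the custom interpreter up outside the shell (settings are assumed; illustration only):

    import scala.tools.nsc.Settings
    import scala.tools.nsc.interpreter.JPrintWriter

    val settings = new Settings
    settings.usejavacp.value = true  // use the JVM classpath for the REPL
    val out = new JPrintWriter(Console.out, true)
    val intp = new org.apache.spark.repl.SparkILoopInterpreter(settings, out)
    intp.interpret("import org.apache.spark.Partition")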

SparkILoopInterpreter.scala (new file)
@@ -0,0 +1,103 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.repl

import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter._

class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) {
  self =>

  override lazy val memberHandlers = new {
    val intp: self.type = self
  } with MemberHandlers {
    import intp.global._

    override def chooseHandler(member: intp.global.Tree): MemberHandler = member match {
      case member: Import => new SparkImportHandler(member)
      case _ => super.chooseHandler(member)
    }

    class SparkImportHandler(imp: Import) extends ImportHandler(imp: Import) {

      override def targetType: Type = intp.global.rootMirror.getModuleIfDefined("" + expr) match {
        case NoSymbol => intp.typeOfExpression("" + expr)
        case sym => sym.tpe
      }

      private def safeIndexOf(name: Name, s: String): Int = fixIndexOf(name, pos(name, s))
      private def fixIndexOf(name: Name, idx: Int): Int = if (idx == name.length) -1 else idx
      private def pos(name: Name, s: String): Int = {
        var i = name.pos(s.charAt(0), 0)
        val sLen = s.length()
        if (sLen == 1) return i
        while (i + sLen <= name.length) {
          var j = 1
          while (s.charAt(j) == name.charAt(i + j)) {
            j += 1
            if (j == sLen) return i
          }
          i = name.pos(s.charAt(0), i + 1)
        }
        name.length
      }

      private def isFlattenedSymbol(sym: Symbol): Boolean =
        sym.owner.isPackageClass &&
          sym.name.containsName(nme.NAME_JOIN_STRING) &&
          sym.owner.info.member(sym.name.take(
            safeIndexOf(sym.name, nme.NAME_JOIN_STRING))) != NoSymbol

      private def importableTargetMembers =
        importableMembers(exitingTyper(targetType)).filterNot(isFlattenedSymbol).toList

      def isIndividualImport(s: ImportSelector): Boolean =
        s.name != nme.WILDCARD && s.rename != nme.WILDCARD
      def isWildcardImport(s: ImportSelector): Boolean =
        s.name == nme.WILDCARD

      // non-wildcard imports
      private def individualSelectors = selectors filter isIndividualImport

      override val importsWildcard: Boolean = selectors exists isWildcardImport

      lazy val importableSymbolsWithRenames: List[(Symbol, Name)] = {
        val selectorRenameMap =
          individualSelectors.flatMap(x => x.name.bothNames zip x.rename.bothNames).toMap
        importableTargetMembers flatMap (m => selectorRenameMap.get(m.name) map (m -> _))
      }

      override lazy val individualSymbols: List[Symbol] = importableSymbolsWithRenames map (_._1)
      override lazy val wildcardSymbols: List[Symbol] =
        if (importsWildcard) importableTargetMembers else Nil

    }

  }

  object expressionTyper extends {
    val repl: SparkILoopInterpreter.this.type = self
  } with SparkExprTyper { }

  override def symbolOfLine(code: String): global.Symbol =
    expressionTyper.symbolOfLine(code)

  override def typeOfExpression(expr: String, silent: Boolean): global.Type =
    expressionTyper.typeOfExpression(expr, silent)

}
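
Two details above are worth unpacking. `isFlattenedSymbol`: after the compiler's flatten phase, a nested class such as `Outer.Inner` also shows up at package level under the joined name `Outer$Inner` (`nme.NAME_JOIN_STRING` is `"$"`); the filter drops those duplicates from wildcard imports whenever the prefix before the join string is itself a member of the owning package. And the `safeIndexOf`/`fixIndexOf`/`pos` trio exists because `Name.pos` signals "not found" by returning `name.length`; a plain-String analogue (a sketch, not the patch's `Name`-based code):

    // Name.pos-style search: "not found" is signalled by name.length...
    def pos(name: String, s: String): Int = {
      val i = name.indexOf(s)
      if (i < 0) name.length else i
    }
    // ...which fixIndexOf normalizes to -1, like String.indexOf.
    def fixIndexOf(name: String, idx: Int): Int = if (idx == name.length) -1 else idx
    def safeIndexOf(name: String, s: String): Int = fixIndexOf(name, pos(name, s))

    // safeIndexOf("Outer$Inner", "$") == 5
    // safeIndexOf("Outer", "$")       == -1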

ReplSuite.scala
@@ -227,4 +227,14 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("error: not found: value sc", output)
}
test("spark-shell should find imported types in class constructors and extends clause") {
val output = runInterpreter("local",
"""
|import org.apache.spark.Partition
|class P(p: Partition)
|class P(val index: Int) extends Partition
""".stripMargin)
assertDoesNotContain("error: not found: type Partition", output)
}
}
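
To run just this suite locally (an assumed invocation for a Spark checkout of this era; module and test names may differ):

    build/sbt "repl/testOnly *ReplSuite"

The new case drives a real local-mode REPL through `runInterpreter` and asserts that the "not found: type Partition" error no longer appears in the captured output.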