[SPARK-8349] [SQL] Use expression constructors (rather than apply) in FunctionRegistry

Author: Reynold Xin <rxin@databricks.com>

Closes #6806 from rxin/gs and squashes the following commits:

ed1aebb [Reynold Xin] Fixed style.
c7fc3e6 [Reynold Xin] [SPARK-8349][SQL] Use expression constructors (rather than apply) in FunctionRegistry
This commit is contained in:
Reynold Xin 2015-06-13 18:22:17 -07:00 committed by Michael Armbrust
parent a138953391
commit 2d71ba4c8a
5 changed files with 22 additions and 32 deletions

View file

@@ -158,27 +158,23 @@ object FunctionRegistry {
/** See usage above. */
private def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, FunctionBuilder) = {
// Use the companion class to find apply methods.
val objectClass = Class.forName(tag.runtimeClass.getName + "$")
val companionObj = objectClass.getDeclaredField("MODULE$").get(null)
// See if we can find an apply that accepts Seq[Expression]
val varargApply = Try(objectClass.getDeclaredMethod("apply", classOf[Seq[_]])).toOption
// See if we can find a constructor that accepts Seq[Expression]
val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption
val builder = (expressions: Seq[Expression]) => {
if (varargApply.isDefined) {
if (varargCtor.isDefined) {
// If there is a constructor that accepts Seq[Expression], use that one.
varargApply.get.invoke(companionObj, expressions).asInstanceOf[Expression]
varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
} else {
// Otherwise, find an apply method that matches the number of arguments, and use that.
// Otherwise, find a constructor that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression])
val f = Try(objectClass.getDeclaredMethod("apply", params : _*)) match {
val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match {
case Success(e) =>
e
case Failure(e) =>
throw new AnalysisException(s"Invalid number of arguments for function $name")
}
f.invoke(companionObj, expressions : _*).asInstanceOf[Expression]
f.newInstance(expressions : _*).asInstanceOf[Expression]
}
}
(name, builder)

View file

@@ -27,8 +27,7 @@ import org.apache.spark.sql.types._
/**
* If an expression wants to be exposed in the function registry (so users can call it with
* "name(arguments...)", the concrete implementation must be a case class whose constructor
* arguments are all Expressions types. In addition, if it needs to support more than one
* constructor, define those constructors explicitly as apply methods in the companion object.
* arguments are all Expression types.
*
* See [[Substring]] for an example.
*/

View file

@@ -49,12 +49,10 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
case class Rand(seed: Long) extends RDG(seed) {
override def eval(input: InternalRow): Double = rng.nextDouble()
}
object Rand {
def apply(): Rand = apply(Utils.random.nextLong())
def this() = this(Utils.random.nextLong())
def apply(seed: Expression): Rand = apply(seed match {
def this(seed: Expression) = this(seed match {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
})
@@ -63,12 +61,10 @@ object Rand {
/** Generate a random column with i.i.d. gaussian random distribution. */
case class Randn(seed: Long) extends RDG(seed) {
override def eval(input: InternalRow): Double = rng.nextGaussian()
}
object Randn {
def apply(): Randn = apply(Utils.random.nextLong())
def this() = this(Utils.random.nextLong())
def apply(seed: Expression): Randn = apply(seed match {
def this(seed: Expression) = this(seed match {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
})

View file

@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.Substring
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -225,6 +226,10 @@ case class EndsWith(left: Expression, right: Expression)
case class Substring(str: Expression, pos: Expression, len: Expression)
extends Expression with ExpectsInputTypes {
def this(str: Expression, pos: Expression) = {
this(str, pos, Literal(Integer.MAX_VALUE))
}
override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
@@ -290,12 +295,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
}
}
object Substring {
def apply(str: Expression, pos: Expression): Substring = {
apply(str, pos, Literal(Integer.MAX_VALUE))
}
}
/**
* A function that return the length of the given string expression.
*/

View file

@@ -344,11 +344,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param newArgs the new product arguments.
*/
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
val defaultCtor =
getClass.getConstructors
.find(_.getParameterTypes.size != 0)
.headOption
.getOrElse(sys.error(s"No valid constructor for $nodeName"))
val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0)
if (ctors.isEmpty) {
sys.error(s"No valid constructor for $nodeName")
}
val defaultCtor = ctors.maxBy(_.getParameterTypes.size)
try {
CurrentOrigin.withOrigin(origin) {