[SPARK-8349] [SQL] Use expression constructors (rather than apply) in FunctionRegistry
Author: Reynold Xin <rxin@databricks.com> Closes #6806 from rxin/gs and squashes the following commits: ed1aebb [Reynold Xin] Fixed style. c7fc3e6 [Reynold Xin] [SPARK-8349][SQL] Use expression constructors (rather than apply) in FunctionRegistry
This commit is contained in:
parent
a138953391
commit
2d71ba4c8a
|
@ -158,27 +158,23 @@ object FunctionRegistry {
|
|||
/** See usage above. */
|
||||
private def expression[T <: Expression](name: String)
|
||||
(implicit tag: ClassTag[T]): (String, FunctionBuilder) = {
|
||||
// Use the companion class to find apply methods.
|
||||
val objectClass = Class.forName(tag.runtimeClass.getName + "$")
|
||||
val companionObj = objectClass.getDeclaredField("MODULE$").get(null)
|
||||
|
||||
// See if we can find an apply that accepts Seq[Expression]
|
||||
val varargApply = Try(objectClass.getDeclaredMethod("apply", classOf[Seq[_]])).toOption
|
||||
|
||||
// See if we can find a constructor that accepts Seq[Expression]
|
||||
val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption
|
||||
val builder = (expressions: Seq[Expression]) => {
|
||||
if (varargApply.isDefined) {
|
||||
if (varargCtor.isDefined) {
|
||||
// If there is an apply method that accepts Seq[Expression], use that one.
|
||||
varargApply.get.invoke(companionObj, expressions).asInstanceOf[Expression]
|
||||
varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
|
||||
} else {
|
||||
// Otherwise, find an apply method that matches the number of arguments, and use that.
|
||||
// Otherwise, find an ctor method that matches the number of arguments, and use that.
|
||||
val params = Seq.fill(expressions.size)(classOf[Expression])
|
||||
val f = Try(objectClass.getDeclaredMethod("apply", params : _*)) match {
|
||||
val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match {
|
||||
case Success(e) =>
|
||||
e
|
||||
case Failure(e) =>
|
||||
throw new AnalysisException(s"Invalid number of arguments for function $name")
|
||||
}
|
||||
f.invoke(companionObj, expressions : _*).asInstanceOf[Expression]
|
||||
f.newInstance(expressions : _*).asInstanceOf[Expression]
|
||||
}
|
||||
}
|
||||
(name, builder)
|
||||
|
|
|
@ -27,8 +27,7 @@ import org.apache.spark.sql.types._
|
|||
/**
|
||||
* If an expression wants to be exposed in the function registry (so users can call it with
|
||||
* "name(arguments...)", the concrete implementation must be a case class whose constructor
|
||||
* arguments are all Expressions types. In addition, if it needs to support more than one
|
||||
* constructor, define those constructors explicitly as apply methods in the companion object.
|
||||
* arguments are all Expressions types.
|
||||
*
|
||||
* See [[Substring]] for an example.
|
||||
*/
|
||||
|
|
|
@ -49,12 +49,10 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
|
|||
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
|
||||
case class Rand(seed: Long) extends RDG(seed) {
|
||||
override def eval(input: InternalRow): Double = rng.nextDouble()
|
||||
}
|
||||
|
||||
object Rand {
|
||||
def apply(): Rand = apply(Utils.random.nextLong())
|
||||
def this() = this(Utils.random.nextLong())
|
||||
|
||||
def apply(seed: Expression): Rand = apply(seed match {
|
||||
def this(seed: Expression) = this(seed match {
|
||||
case IntegerLiteral(s) => s
|
||||
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
|
||||
})
|
||||
|
@ -63,12 +61,10 @@ object Rand {
|
|||
/** Generate a random column with i.i.d. gaussian random distribution. */
|
||||
case class Randn(seed: Long) extends RDG(seed) {
|
||||
override def eval(input: InternalRow): Double = rng.nextGaussian()
|
||||
}
|
||||
|
||||
object Randn {
|
||||
def apply(): Randn = apply(Utils.random.nextLong())
|
||||
def this() = this(Utils.random.nextLong())
|
||||
|
||||
def apply(seed: Expression): Randn = apply(seed match {
|
||||
def this(seed: Expression) = this(seed match {
|
||||
case IntegerLiteral(s) => s
|
||||
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
|
||||
})
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import java.util.regex.Pattern
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
|
||||
import org.apache.spark.sql.catalyst.expressions.Substring
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
@ -225,6 +226,10 @@ case class EndsWith(left: Expression, right: Expression)
|
|||
case class Substring(str: Expression, pos: Expression, len: Expression)
|
||||
extends Expression with ExpectsInputTypes {
|
||||
|
||||
def this(str: Expression, pos: Expression) = {
|
||||
this(str, pos, Literal(Integer.MAX_VALUE))
|
||||
}
|
||||
|
||||
override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
|
||||
|
||||
override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
|
||||
|
@ -290,12 +295,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
|
|||
}
|
||||
}
|
||||
|
||||
object Substring {
|
||||
def apply(str: Expression, pos: Expression): Substring = {
|
||||
apply(str, pos, Literal(Integer.MAX_VALUE))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A function that return the length of the given string expression.
|
||||
*/
|
||||
|
|
|
@ -344,11 +344,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
|
|||
* @param newArgs the new product arguments.
|
||||
*/
|
||||
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
|
||||
val defaultCtor =
|
||||
getClass.getConstructors
|
||||
.find(_.getParameterTypes.size != 0)
|
||||
.headOption
|
||||
.getOrElse(sys.error(s"No valid constructor for $nodeName"))
|
||||
val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0)
|
||||
if (ctors.isEmpty) {
|
||||
sys.error(s"No valid constructor for $nodeName")
|
||||
}
|
||||
val defaultCtor = ctors.maxBy(_.getParameterTypes.size)
|
||||
|
||||
try {
|
||||
CurrentOrigin.withOrigin(origin) {
|
||||
|
|
Loading…
Reference in a new issue