[SPARK-7886] Add built-in expressions to FunctionRegistry.
This patch switches to using FunctionRegistry for built-in expressions. It is based on #6463, but with some work to simplify it along with unit tests. TODOs for future pull requests: - Use static registration so we don't need to register all functions every time we start a new SQLContext - Switch to using this in HiveContext Author: Reynold Xin <rxin@databricks.com> Author: Santiago M. Mola <santi@mola.io> Closes #6710 from rxin/udf-registry and squashes the following commits: 6930822 [Reynold Xin] Fixed Python test. b802c9a [Reynold Xin] Made UDF case insensitive. e60d815 [Reynold Xin] Made UDF case insensitive. 852f9c0 [Reynold Xin] Fixed style violation. e76a3c1 [Reynold Xin] Fixed parser. 52ddaba [Reynold Xin] Fixed compilation. ee7854f [Reynold Xin] Improved error reporting. ff906f2 [Reynold Xin] More robust constructor calling. 77b46f1 [Reynold Xin] Simplified the code. 2a2a149 [Reynold Xin] Merge pull request #6463 from smola/SPARK-7886 8616924 [Santiago M. Mola] [SPARK-7886] Add built-in expressions to FunctionRegistry.
This commit is contained in:
parent
0902a11940
commit
1b499993ad
|
@ -746,7 +746,7 @@ class DataFrame(object):
|
|||
This is a variant of :func:`select` that accepts SQL expressions.
|
||||
|
||||
>>> df.selectExpr("age * 2", "abs(age)").collect()
|
||||
[Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
|
||||
[Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)]
|
||||
"""
|
||||
if len(expr) == 1 and isinstance(expr[0], list):
|
||||
expr = expr[0]
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst
|
|||
|
||||
import scala.language.implicitConversions
|
||||
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
|
@ -48,26 +49,21 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
|
|||
|
||||
// Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
|
||||
// properties via reflection the class in runtime for constructing the SqlLexical object
|
||||
protected val ABS = Keyword("ABS")
|
||||
protected val ALL = Keyword("ALL")
|
||||
protected val AND = Keyword("AND")
|
||||
protected val APPROXIMATE = Keyword("APPROXIMATE")
|
||||
protected val AS = Keyword("AS")
|
||||
protected val ASC = Keyword("ASC")
|
||||
protected val AVG = Keyword("AVG")
|
||||
protected val BETWEEN = Keyword("BETWEEN")
|
||||
protected val BY = Keyword("BY")
|
||||
protected val CASE = Keyword("CASE")
|
||||
protected val CAST = Keyword("CAST")
|
||||
protected val COALESCE = Keyword("COALESCE")
|
||||
protected val COUNT = Keyword("COUNT")
|
||||
protected val DESC = Keyword("DESC")
|
||||
protected val DISTINCT = Keyword("DISTINCT")
|
||||
protected val ELSE = Keyword("ELSE")
|
||||
protected val END = Keyword("END")
|
||||
protected val EXCEPT = Keyword("EXCEPT")
|
||||
protected val FALSE = Keyword("FALSE")
|
||||
protected val FIRST = Keyword("FIRST")
|
||||
protected val FROM = Keyword("FROM")
|
||||
protected val FULL = Keyword("FULL")
|
||||
protected val GROUP = Keyword("GROUP")
|
||||
|
@ -80,13 +76,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
|
|||
protected val INTO = Keyword("INTO")
|
||||
protected val IS = Keyword("IS")
|
||||
protected val JOIN = Keyword("JOIN")
|
||||
protected val LAST = Keyword("LAST")
|
||||
protected val LEFT = Keyword("LEFT")
|
||||
protected val LIKE = Keyword("LIKE")
|
||||
protected val LIMIT = Keyword("LIMIT")
|
||||
protected val LOWER = Keyword("LOWER")
|
||||
protected val MAX = Keyword("MAX")
|
||||
protected val MIN = Keyword("MIN")
|
||||
protected val NOT = Keyword("NOT")
|
||||
protected val NULL = Keyword("NULL")
|
||||
protected val ON = Keyword("ON")
|
||||
|
@ -100,15 +92,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
|
|||
protected val RLIKE = Keyword("RLIKE")
|
||||
protected val SELECT = Keyword("SELECT")
|
||||
protected val SEMI = Keyword("SEMI")
|
||||
protected val SQRT = Keyword("SQRT")
|
||||
protected val SUBSTR = Keyword("SUBSTR")
|
||||
protected val SUBSTRING = Keyword("SUBSTRING")
|
||||
protected val SUM = Keyword("SUM")
|
||||
protected val TABLE = Keyword("TABLE")
|
||||
protected val THEN = Keyword("THEN")
|
||||
protected val TRUE = Keyword("TRUE")
|
||||
protected val UNION = Keyword("UNION")
|
||||
protected val UPPER = Keyword("UPPER")
|
||||
protected val WHEN = Keyword("WHEN")
|
||||
protected val WHERE = Keyword("WHERE")
|
||||
protected val WITH = Keyword("WITH")
|
||||
|
@ -277,25 +264,36 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
|
|||
)
|
||||
|
||||
protected lazy val function: Parser[Expression] =
|
||||
( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) }
|
||||
| SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) }
|
||||
| COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) }
|
||||
| COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) }
|
||||
| COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^
|
||||
{ case exps => CountDistinct(exps) }
|
||||
| APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^
|
||||
{ case exp => ApproxCountDistinct(exp) }
|
||||
| APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
|
||||
{ case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) }
|
||||
| FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) }
|
||||
| LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) }
|
||||
| AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) }
|
||||
| MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) }
|
||||
| MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) }
|
||||
| UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) }
|
||||
| LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) }
|
||||
| IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
|
||||
{ case c ~ t ~ f => If(c, t, f) }
|
||||
( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName =>
|
||||
if (lexical.normalizeKeyword(udfName) == "count") {
|
||||
Count(Literal(1))
|
||||
} else {
|
||||
throw new AnalysisException(s"invalid expression $udfName(*)")
|
||||
}
|
||||
}
|
||||
| ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
|
||||
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
|
||||
| ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
|
||||
lexical.normalizeKeyword(udfName) match {
|
||||
case "sum" => SumDistinct(exprs.head)
|
||||
case "count" => CountDistinct(exprs)
|
||||
}
|
||||
}
|
||||
| APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp =>
|
||||
if (lexical.normalizeKeyword(udfName) == "count") {
|
||||
ApproxCountDistinct(exp)
|
||||
} else {
|
||||
throw new AnalysisException(s"invalid function approximate $udfName")
|
||||
}
|
||||
}
|
||||
| APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
|
||||
{ case s ~ _ ~ udfName ~ _ ~ _ ~ exp =>
|
||||
if (lexical.normalizeKeyword(udfName) == "count") {
|
||||
ApproxCountDistinct(exp, s.toDouble)
|
||||
} else {
|
||||
throw new AnalysisException(s"invalid function approximate($floatLit) $udfName")
|
||||
}
|
||||
}
|
||||
| CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~
|
||||
(ELSE ~> expression).? <~ END ^^ {
|
||||
case casePart ~ altPart ~ elsePart =>
|
||||
|
@ -304,16 +302,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
|
|||
} ++ elsePart
|
||||
casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches))
|
||||
}
|
||||
| (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^
|
||||
{ case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }
|
||||
| (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
|
||||
{ case s ~ p ~ l => Substring(s, p, l) }
|
||||
| COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) }
|
||||
| SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) }
|
||||
| ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) }
|
||||
| ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
|
||||
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
|
||||
)
|
||||
)
|
||||
|
||||
protected lazy val cast: Parser[Expression] =
|
||||
CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ {
|
||||
|
|
|
@ -461,7 +461,9 @@ class Analyzer(
|
|||
case q: LogicalPlan =>
|
||||
q transformExpressions {
|
||||
case u @ UnresolvedFunction(name, children) if u.childrenResolved =>
|
||||
registry.lookupFunction(name, children)
|
||||
withPosition(u) {
|
||||
registry.lookupFunction(name, children)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,24 +17,27 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.analysis
|
||||
|
||||
import org.apache.spark.sql.catalyst.CatalystConf
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import scala.collection.mutable
|
||||
import scala.reflect.ClassTag
|
||||
import scala.util.{Failure, Success, Try}
|
||||
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.util.StringKeyHashMap
|
||||
|
||||
|
||||
/** A catalog for looking up user defined functions, used by an [[Analyzer]]. */
|
||||
trait FunctionRegistry {
|
||||
type FunctionBuilder = Seq[Expression] => Expression
|
||||
|
||||
def registerFunction(name: String, builder: FunctionBuilder): Unit
|
||||
|
||||
@throws[AnalysisException]("If function does not exist")
|
||||
def lookupFunction(name: String, children: Seq[Expression]): Expression
|
||||
|
||||
def conf: CatalystConf
|
||||
}
|
||||
|
||||
trait OverrideFunctionRegistry extends FunctionRegistry {
|
||||
|
||||
val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis)
|
||||
private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false)
|
||||
|
||||
override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
|
||||
functionBuilders.put(name, builder)
|
||||
|
@ -45,16 +48,19 @@ trait OverrideFunctionRegistry extends FunctionRegistry {
|
|||
}
|
||||
}
|
||||
|
||||
class SimpleFunctionRegistry(val conf: CatalystConf) extends FunctionRegistry {
|
||||
class SimpleFunctionRegistry extends FunctionRegistry {
|
||||
|
||||
val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis)
|
||||
private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false)
|
||||
|
||||
override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
|
||||
functionBuilders.put(name, builder)
|
||||
}
|
||||
|
||||
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
|
||||
functionBuilders(name)(children)
|
||||
val func = functionBuilders.get(name).getOrElse {
|
||||
throw new AnalysisException(s"undefined function $name")
|
||||
}
|
||||
func(children)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,30 +76,89 @@ object EmptyFunctionRegistry extends FunctionRegistry {
|
|||
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
|
||||
throw new UnsupportedOperationException
|
||||
}
|
||||
|
||||
override def conf: CatalystConf = throw new UnsupportedOperationException
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a map with String type of key, and it also supports either key case
|
||||
* sensitive or insensitive.
|
||||
* TODO move this into util folder?
|
||||
*/
|
||||
object StringKeyHashMap {
|
||||
def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match {
|
||||
case false => new StringKeyHashMap[T](_.toLowerCase)
|
||||
case true => new StringKeyHashMap[T](identity)
|
||||
|
||||
object FunctionRegistry {
|
||||
|
||||
type FunctionBuilder = Seq[Expression] => Expression
|
||||
|
||||
val expressions: Map[String, FunctionBuilder] = Map(
|
||||
// Non aggregate functions
|
||||
expression[Abs]("abs"),
|
||||
expression[CreateArray]("array"),
|
||||
expression[Coalesce]("coalesce"),
|
||||
expression[Explode]("explode"),
|
||||
expression[Lower]("lower"),
|
||||
expression[Substring]("substr"),
|
||||
expression[Substring]("substring"),
|
||||
expression[Rand]("rand"),
|
||||
expression[Randn]("randn"),
|
||||
expression[CreateStruct]("struct"),
|
||||
expression[Sqrt]("sqrt"),
|
||||
expression[Upper]("upper"),
|
||||
|
||||
// Math functions
|
||||
expression[Acos]("acos"),
|
||||
expression[Asin]("asin"),
|
||||
expression[Atan]("atan"),
|
||||
expression[Atan2]("atan2"),
|
||||
expression[Cbrt]("cbrt"),
|
||||
expression[Ceil]("ceil"),
|
||||
expression[Cos]("cos"),
|
||||
expression[Exp]("exp"),
|
||||
expression[Expm1]("expm1"),
|
||||
expression[Floor]("floor"),
|
||||
expression[Hypot]("hypot"),
|
||||
expression[Log]("log"),
|
||||
expression[Log10]("log10"),
|
||||
expression[Log1p]("log1p"),
|
||||
expression[Pow]("pow"),
|
||||
expression[Rint]("rint"),
|
||||
expression[Signum]("signum"),
|
||||
expression[Sin]("sin"),
|
||||
expression[Sinh]("sinh"),
|
||||
expression[Tan]("tan"),
|
||||
expression[Tanh]("tanh"),
|
||||
expression[ToDegrees]("todegrees"),
|
||||
expression[ToRadians]("toradians"),
|
||||
|
||||
// aggregate functions
|
||||
expression[Average]("avg"),
|
||||
expression[Count]("count"),
|
||||
expression[First]("first"),
|
||||
expression[Last]("last"),
|
||||
expression[Max]("max"),
|
||||
expression[Min]("min"),
|
||||
expression[Sum]("sum")
|
||||
)
|
||||
|
||||
/** See usage above. */
|
||||
private def expression[T <: Expression](name: String)
|
||||
(implicit tag: ClassTag[T]): (String, FunctionBuilder) = {
|
||||
// Use the companion class to find apply methods.
|
||||
val objectClass = Class.forName(tag.runtimeClass.getName + "$")
|
||||
val companionObj = objectClass.getDeclaredField("MODULE$").get(null)
|
||||
|
||||
// See if we can find an apply that accepts Seq[Expression]
|
||||
val varargApply = Try(objectClass.getDeclaredMethod("apply", classOf[Seq[_]])).toOption
|
||||
|
||||
val builder = (expressions: Seq[Expression]) => {
|
||||
if (varargApply.isDefined) {
|
||||
// If there is an apply method that accepts Seq[Expression], use that one.
|
||||
varargApply.get.invoke(companionObj, expressions).asInstanceOf[Expression]
|
||||
} else {
|
||||
// Otherwise, find an apply method that matches the number of arguments, and use that.
|
||||
val params = Seq.fill(expressions.size)(classOf[Expression])
|
||||
val f = Try(objectClass.getDeclaredMethod("apply", params : _*)) match {
|
||||
case Success(e) =>
|
||||
e
|
||||
case Failure(e) =>
|
||||
throw new AnalysisException(s"Invalid number of arguments for function $name")
|
||||
}
|
||||
f.invoke(companionObj, expressions : _*).asInstanceOf[Expression]
|
||||
}
|
||||
}
|
||||
(name, builder)
|
||||
}
|
||||
}
|
||||
|
||||
class StringKeyHashMap[T](normalizer: (String) => String) {
|
||||
private val base = new collection.mutable.HashMap[String, T]()
|
||||
|
||||
def apply(key: String): T = base(normalizer(key))
|
||||
|
||||
def get(key: String): Option[T] = base.get(normalizer(key))
|
||||
def put(key: String, value: T): Option[T] = base.put(normalizer(key), value)
|
||||
def remove(key: String): Option[T] = base.remove(normalizer(key))
|
||||
def iterator: Iterator[(String, T)] = base.toIterator
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,15 @@ import org.apache.spark.sql.catalyst.trees
|
|||
import org.apache.spark.sql.catalyst.trees.TreeNode
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
||||
/**
|
||||
* For Catalyst to work correctly, concrete implementations of [[Expression]]s must be case classes
|
||||
* whose constructor arguments are all Expressions types. In addition, if we want to support more
|
||||
* than one constructor, define those constructors explicitly as apply methods in the companion
|
||||
* object.
|
||||
*
|
||||
* See [[Substring]] for an example.
|
||||
*/
|
||||
abstract class Expression extends TreeNode[Expression] {
|
||||
self: Product =>
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.types.{DataType, DoubleType}
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.util.random.XORShiftRandom
|
||||
|
@ -46,11 +47,29 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
|
|||
}
|
||||
|
||||
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
|
||||
case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
|
||||
case class Rand(seed: Long) extends RDG(seed) {
|
||||
override def eval(input: Row): Double = rng.nextDouble()
|
||||
}
|
||||
|
||||
object Rand {
|
||||
def apply(): Rand = apply(Utils.random.nextLong())
|
||||
|
||||
def apply(seed: Expression): Rand = apply(seed match {
|
||||
case IntegerLiteral(s) => s
|
||||
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
|
||||
})
|
||||
}
|
||||
|
||||
/** Generate a random column with i.i.d. gaussian random distribution. */
|
||||
case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
|
||||
case class Randn(seed: Long) extends RDG(seed) {
|
||||
override def eval(input: Row): Double = rng.nextGaussian()
|
||||
}
|
||||
|
||||
object Randn {
|
||||
def apply(): Randn = apply(Utils.random.nextLong())
|
||||
|
||||
def apply(seed: Expression): Randn = apply(seed match {
|
||||
case IntegerLiteral(s) => s
|
||||
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
|
||||
})
|
||||
}
|
||||
|
|
|
@ -227,6 +227,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
|
|||
override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
|
||||
|
||||
override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
|
||||
|
||||
override def dataType: DataType = {
|
||||
if (!resolved) {
|
||||
throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved")
|
||||
|
@ -287,3 +288,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
|
|||
case _ => s"SUBSTR($str, $pos, $len)"
|
||||
}
|
||||
}
|
||||
|
||||
object Substring {
|
||||
def apply(str: Expression, pos: Expression): Substring = {
|
||||
apply(str, pos, Literal(Integer.MAX_VALUE))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.util
|
||||
|
||||
/**
|
||||
* Build a map with String type of key, and it also supports either key case
|
||||
* sensitive or insensitive.
|
||||
*/
|
||||
object StringKeyHashMap {
|
||||
def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match {
|
||||
case false => new StringKeyHashMap[T](_.toLowerCase)
|
||||
case true => new StringKeyHashMap[T](identity)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class StringKeyHashMap[T](normalizer: (String) => String) {
|
||||
private val base = new collection.mutable.HashMap[String, T]()
|
||||
|
||||
def apply(key: String): T = base(normalizer(key))
|
||||
|
||||
def get(key: String): Option[T] = base.get(normalizer(key))
|
||||
|
||||
def put(key: String, value: T): Option[T] = base.put(normalizer(key), value)
|
||||
|
||||
def remove(key: String): Option[T] = base.remove(normalizer(key))
|
||||
|
||||
def iterator: Iterator[(String, T)] = base.toIterator
|
||||
}
|
|
@ -120,7 +120,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
|
||||
// TODO how to handle the temp function per user session?
|
||||
@transient
|
||||
protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(conf)
|
||||
protected[sql] lazy val functionRegistry: FunctionRegistry = {
|
||||
val fr = new SimpleFunctionRegistry
|
||||
FunctionRegistry.expressions.foreach { case (name, func) => fr.registerFunction(name, func) }
|
||||
fr
|
||||
}
|
||||
|
||||
@transient
|
||||
protected[sql] lazy val analyzer: Analyzer =
|
||||
|
|
|
@ -25,6 +25,48 @@ class UDFSuite extends QueryTest {
|
|||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
|
||||
test("built-in fixed arity expressions") {
|
||||
val df = ctx.emptyDataFrame
|
||||
df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)")
|
||||
}
|
||||
|
||||
test("built-in vararg expressions") {
|
||||
val df = Seq((1, 2)).toDF("a", "b")
|
||||
df.selectExpr("array(a, b)")
|
||||
df.selectExpr("struct(a, b)")
|
||||
}
|
||||
|
||||
test("built-in expressions with multiple constructors") {
|
||||
val df = Seq(("abcd", 2)).toDF("a", "b")
|
||||
df.selectExpr("substr(a, 2)", "substr(a, 2, 3)").collect()
|
||||
}
|
||||
|
||||
test("count") {
|
||||
val df = Seq(("abcd", 2)).toDF("a", "b")
|
||||
df.selectExpr("count(a)")
|
||||
}
|
||||
|
||||
test("count distinct") {
|
||||
val df = Seq(("abcd", 2)).toDF("a", "b")
|
||||
df.selectExpr("count(distinct a)")
|
||||
}
|
||||
|
||||
test("error reporting for incorrect number of arguments") {
|
||||
val df = ctx.emptyDataFrame
|
||||
val e = intercept[AnalysisException] {
|
||||
df.selectExpr("substr('abcd', 2, 3, 4)")
|
||||
}
|
||||
assert(e.getMessage.contains("arguments"))
|
||||
}
|
||||
|
||||
test("error reporting for undefined functions") {
|
||||
val df = ctx.emptyDataFrame
|
||||
val e = intercept[AnalysisException] {
|
||||
df.selectExpr("a_function_that_does_not_exist()")
|
||||
}
|
||||
assert(e.getMessage.contains("undefined function"))
|
||||
}
|
||||
|
||||
test("Simple UDF") {
|
||||
ctx.udf.register("strLenScala", (_: String).length)
|
||||
assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4)
|
||||
|
|
|
@ -39,13 +39,12 @@ import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable}
|
|||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry}
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand}
|
||||
import org.apache.spark.sql.hive.client._
|
||||
import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand}
|
||||
import org.apache.spark.sql.sources.DataSourceStrategy
|
||||
import org.apache.spark.sql.catalyst.CatalystConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
@ -374,10 +373,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
|
|||
|
||||
// Note that HiveUDFs will be overridden by functions registered in this context.
|
||||
@transient
|
||||
override protected[sql] lazy val functionRegistry =
|
||||
new HiveFunctionRegistry with OverrideFunctionRegistry {
|
||||
override def conf: CatalystConf = currentSession().conf
|
||||
}
|
||||
override protected[sql] lazy val functionRegistry: FunctionRegistry =
|
||||
new HiveFunctionRegistry with OverrideFunctionRegistry
|
||||
|
||||
/* An analyzer that uses the Hive metastore. */
|
||||
@transient
|
||||
|
|
|
@ -17,11 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.hive
|
||||
|
||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
|
||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector}
|
||||
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
|
||||
|
@ -30,8 +27,11 @@ import org.apache.hadoop.hive.ql.exec._
|
|||
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
|
||||
import org.apache.hadoop.hive.ql.udf.generic._
|
||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
|
||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
|
||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.analysis
|
||||
import org.apache.spark.sql.catalyst.errors.TreeNodeException
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
@ -40,20 +40,18 @@ import org.apache.spark.sql.catalyst.rules.Rule
|
|||
import org.apache.spark.sql.hive.HiveShim._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/* Implicit conversions */
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
private[hive] abstract class HiveFunctionRegistry
|
||||
extends analysis.FunctionRegistry with HiveInspectors {
|
||||
|
||||
def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
|
||||
|
||||
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
|
||||
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
|
||||
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
|
||||
// not always serializable.
|
||||
val functionInfo: FunctionInfo =
|
||||
Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
|
||||
sys.error(s"Couldn't find function $name"))
|
||||
throw new AnalysisException(s"undefined function $name"))
|
||||
|
||||
val functionClassName = functionInfo.getFunctionClass.getName
|
||||
|
||||
|
|
Loading…
Reference in a new issue