[SPARK-34981][SQL] Implement V2 function resolution and evaluation

Co-Authored-By: Chao Sun <sunchao@apple.com>
Co-Authored-By: Ryan Blue <rblue@netflix.com>

### What changes were proposed in this pull request?

This implements function resolution and evaluation for functions registered through V2 FunctionCatalog [SPARK-27658](https://issues.apache.org/jira/browse/SPARK-27658). In particular:
- Added documentation for how to define the "magic method" in `ScalarFunction`.
- Added a new expression `ApplyFunctionExpression` which evaluates input by delegating to the `ScalarFunction.produceResult` method.
- Added a new expression `V2Aggregator`, a `TypedImperativeAggregate` that wraps a V2 `AggregateFunction` and mostly delegates to the latter's implementation. It uses plain Java serde for intermediate state.
- Added function resolution logic for `ScalarFunction` and `AggregateFunction` in `Analyzer`.
  + For `ScalarFunction`, this checks via Java reflection whether the magic method is implemented and, if so, creates an `Invoke` expression. Otherwise, it checks whether the default `produceResult` is overridden; if so, it creates an `ApplyFunctionExpression` which evaluates through `InternalRow`. Otherwise an analysis exception is thrown. (A Scala sketch of both paths follows this list.)
  + For `AggregateFunction`, this checks whether the `update` method is overridden. If so, it converts the function to a `V2Aggregator`; otherwise an analysis exception is thrown, as in the `ScalarFunction` case.
- Extended existing `InMemoryTableCatalog` to add the function catalog capability. Also renamed it to `InMemoryCatalog` since it no longer only covers tables.
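
As a rough illustration of the two resolution paths (the function names here are hypothetical and not part of this patch), a `ScalarFunction` can either expose the magic method or override `produceResult`:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{DataType, DataTypes}

// Resolved to an Invoke expression via the magic method `invoke`.
object PlusOne extends ScalarFunction[Int] {
  override def inputTypes(): Array[DataType] = Array(DataTypes.IntegerType)
  override def resultType(): DataType = DataTypes.IntegerType
  override def name(): String = "plus_one"
  def invoke(x: Int): Int = x + 1
}

// Resolved to an ApplyFunctionExpression via the `produceResult` fallback.
object PlusOneRow extends ScalarFunction[Int] {
  override def inputTypes(): Array[DataType] = Array(DataTypes.IntegerType)
  override def resultType(): DataType = DataTypes.IntegerType
  override def name(): String = "plus_one_row"
  override def produceResult(input: InternalRow): Int = input.getInt(0) + 1
}
```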

**Note**: detecting whether a subclass overrides the default `produceResult` or `update` method from the parent interface currently works **only for Java implementations**. In Scala it is hard to tell whether a subclass overrides a default method inherited from its parent interface, so for Scala implementations a missing override surfaces as a runtime error rather than an analysis error.
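
For reference, a simplified sketch of the reflection check the analyzer relies on (hypothetical helper, not part of this patch). For a Java UDF, `getDeclaredMethod` only succeeds when the default method is actually overridden; for a Scala UDF the method may still be found even without an explicit override, which is why the detection above is limited to Java implementations:

```scala
import scala.util.Try
import org.apache.spark.sql.catalyst.InternalRow

// True if `produceResult(InternalRow)` is declared directly on the UDF class.
def declaresProduceResult(udfClass: Class[_]): Boolean =
  Try(udfClass.getDeclaredMethod("produceResult", classOf[InternalRow])).isSuccess
```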

A few TODOs:
- Extend `V2SessionCatalog` with function catalog support. This seems a little tricky since APIs such as the V2 `FunctionCatalog`'s `loadFunction` differ from the V1 `SessionCatalog`'s `lookupFunction`.
- Add a magic method for `AggregateFunction`.
- Type coercion when looking up functions.

### Why are the changes needed?

Now that the V2 FunctionCatalog APIs are finalized, we should integrate them with the function resolution and evaluation process so that they are actually useful.

### Does this PR introduce _any_ user-facing change?

Yes, now a function exposed through V2 FunctionCatalog can be analyzed and evaluated.

### How was this patch tested?

Added new unit tests.

Closes #32082 from sunchao/resolve-func-v2.

Lead-authored-by: Chao Sun <sunchao@apple.com>
Co-authored-by: Chao Sun <sunchao@apache.org>
Co-authored-by: Chao Sun <sunchao@uber.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Chao Sun 2021-04-28 17:21:49 +00:00 committed by Wenchen Fan
parent 0bcf348438
commit 86d3bb5f7d
21 changed files with 1163 additions and 76 deletions


@ -25,13 +25,12 @@ import java.io.Serializable;
/**
* Interface for a function that produces a result value by aggregating over multiple input rows.
* <p>
* For each input row, Spark will call an update method that corresponds to the
* {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
* Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
* update with {@link InternalRow}.
* <p>
* The JVM type of result values produced by this function must be the type used by Spark's
* For each input row, Spark will call the {@link #update} method which should evaluate the row
* and update the aggregation state. The JVM type of result values produced by
* {@link #produceResult} must be the type used by Spark's
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
* Please refer to class documentation of {@link ScalarFunction} for the mapping between
* {@link DataType} and the JVM type.
* <p>
* All implementations must support partial aggregation by implementing merge so that Spark can
* partially aggregate and shuffle intermediate results, instead of shuffling all rows for an
@ -68,9 +67,7 @@ public interface AggregateFunction<S extends Serializable, R> extends BoundFunct
* @param input an input row
* @return updated aggregation state
*/
default S update(S state, InternalRow input) {
throw new UnsupportedOperationException("Cannot find a compatible AggregateFunction#update");
}
S update(S state, InternalRow input);
/**
* Merge two partial aggregation states.
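
As a concrete illustration of the contract described above (`update` per input row, `merge` for partial aggregation, `produceResult` at the end), here is a hedged Scala sketch of a simple sum aggregate; the names are hypothetical and it is not part of this patch:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.AggregateFunction
import org.apache.spark.sql.types.{DataType, DataTypes}

// The state is the running sum; it must be java.io.Serializable.
object LongSum extends AggregateFunction[java.lang.Long, java.lang.Long] {
  override def name(): String = "lsum"
  override def inputTypes(): Array[DataType] = Array(DataTypes.LongType)
  override def resultType(): DataType = DataTypes.LongType
  override def newAggregationState(): java.lang.Long = 0L
  override def update(state: java.lang.Long, input: InternalRow): java.lang.Long =
    if (input.isNullAt(0)) state else state + input.getLong(0)
  override def merge(left: java.lang.Long, right: java.lang.Long): java.lang.Long = left + right
  override def produceResult(state: java.lang.Long): java.lang.Long = state
}
```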


@ -23,17 +23,67 @@ import org.apache.spark.sql.types.DataType;
/**
* Interface for a function that produces a result value for each input row.
* <p>
* For each input row, Spark will call a produceResult method that corresponds to the
* {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
* Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
* {@link #produceResult(InternalRow)}.
* To evaluate each input row, Spark will first try to look up and use a "magic method" (described
* below) through Java reflection. If the method is not found, Spark will call
* {@link #produceResult(InternalRow)} as a fallback approach.
* <p>
* The JVM type of result values produced by this function must be the type used by Spark's
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
* <p>
* <b>IMPORTANT</b>: the default implementation of {@link #produceResult} throws
* {@link UnsupportedOperationException}. Users can choose to override this method, or implement
* a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters
* instead of an {@link InternalRow}. The magic method will be loaded by Spark through Java
* reflection and will also provide better performance in general, due to optimizations such as
* codegen, removal of Java boxing, etc.
*
* For example, a scalar UDF for adding two integers can be defined as follows with the magic
* method approach:
*
* <pre>
* public class IntegerAdd implements {@code ScalarFunction<Integer>} {
* public int invoke(int left, int right) {
* return left + right;
* }
* }
* </pre>
* In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over
* {@link #produceResult} to evaluate the inputs. In general, Spark looks up the magic method by
* first converting the actual input SQL data types to their corresponding Java types following
* the mapping defined below, and then checking if there is a matching method among all the
* declared methods in the UDF class, using the method name (i.e., {@link #MAGIC_METHOD_NAME}) and
* the Java types. If no magic method is found, Spark falls back to {@link #produceResult}.
* <p>
* The following is the mapping from {@link DataType SQL data type} to Java type used by
* the magic method approach:
* <ul>
* <li>{@link org.apache.spark.sql.types.BooleanType}: {@code boolean}</li>
* <li>{@link org.apache.spark.sql.types.ByteType}: {@code byte}</li>
* <li>{@link org.apache.spark.sql.types.ShortType}: {@code short}</li>
* <li>{@link org.apache.spark.sql.types.IntegerType}: {@code int}</li>
* <li>{@link org.apache.spark.sql.types.LongType}: {@code long}</li>
* <li>{@link org.apache.spark.sql.types.FloatType}: {@code float}</li>
* <li>{@link org.apache.spark.sql.types.DoubleType}: {@code double}</li>
* <li>{@link org.apache.spark.sql.types.StringType}:
* {@link org.apache.spark.unsafe.types.UTF8String}</li>
* <li>{@link org.apache.spark.sql.types.DateType}: {@code int}</li>
* <li>{@link org.apache.spark.sql.types.TimestampType}: {@code long}</li>
* <li>{@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}</li>
* <li>{@link org.apache.spark.sql.types.DayTimeIntervalType}: {@code long}</li>
* <li>{@link org.apache.spark.sql.types.YearMonthIntervalType}: {@code int}</li>
* <li>{@link org.apache.spark.sql.types.DecimalType}:
* {@link org.apache.spark.sql.types.Decimal}</li>
* <li>{@link org.apache.spark.sql.types.StructType}: {@link InternalRow}</li>
* <li>{@link org.apache.spark.sql.types.ArrayType}:
* {@link org.apache.spark.sql.catalyst.util.ArrayData}</li>
* <li>{@link org.apache.spark.sql.types.MapType}:
* {@link org.apache.spark.sql.catalyst.util.MapData}</li>
* </ul>
*
* @param <R> the JVM type of result values
*/
public interface ScalarFunction<R> extends BoundFunction {
String MAGIC_METHOD_NAME = "invoke";
/**
* Applies the function to an input row to produce a value.
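
To make the SQL-to-Java type mapping listed above concrete, a small hedged Scala sketch (hypothetical function, not part of this patch): a DATE argument reaches the magic method as an `int` holding days since the epoch, and a DATE result is returned as an `int` as well.

```scala
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{DataType, DataTypes}

// Bound to (DATE, INT) inputs with a DATE result: both dates are int days-since-epoch
// at the magic-method level, per the mapping in the javadoc above.
object DateAddDays extends ScalarFunction[Int] {
  override def inputTypes(): Array[DataType] = Array(DataTypes.DateType, DataTypes.IntegerType)
  override def resultType(): DataType = DataTypes.DateType
  override def name(): String = "date_add_days"
  def invoke(daysSinceEpoch: Int, offset: Int): Int = daysSinceEpoch + offset
}
```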


@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
import java.lang.reflect.Method
import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
@ -29,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions.{FrameLessOffsetWindowFunction, _}
import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _}
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects._
@ -44,6 +45,8 @@ import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction}
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@ -281,7 +284,7 @@ class Analyzer(override val catalogManager: CatalogManager)
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables ::
ResolveHigherOrderFunctions(v1SessionCatalog) ::
ResolveHigherOrderFunctions(catalogManager) ::
ResolveLambdaVariables ::
ResolveTimeZone ::
ResolveRandomSeed ::
@ -895,9 +898,10 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}
// If we are resolving relations insides views, we need to expand single-part relation names with
// the current catalog and namespace of when the view was created.
private def expandRelationName(nameParts: Seq[String]): Seq[String] = {
// If we are resolving database objects (relations, functions, etc.) inside views, we may need to
// expand single or multi-part identifiers with the current catalog and namespace of when the
// view was created.
private def expandIdentifier(nameParts: Seq[String]): Seq[String] = {
if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts
if (nameParts.length == 1) {
@ -1040,7 +1044,7 @@ class Analyzer(override val catalogManager: CatalogManager)
identifier: Seq[String],
options: CaseInsensitiveStringMap,
isStreaming: Boolean): Option[LogicalPlan] =
expandRelationName(identifier) match {
expandIdentifier(identifier) match {
case NonSessionCatalogAndIdentifier(catalog, ident) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(table) =>
@ -1153,7 +1157,7 @@ class Analyzer(override val catalogManager: CatalogManager)
}
private def lookupTableOrView(identifier: Seq[String]): Option[LogicalPlan] = {
expandRelationName(identifier) match {
expandIdentifier(identifier) match {
case SessionCatalogAndIdentifier(catalog, ident) =>
CatalogV2Util.loadTable(catalog, ident).map {
case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW =>
@ -1173,7 +1177,7 @@ class Analyzer(override val catalogManager: CatalogManager)
identifier: Seq[String],
options: CaseInsensitiveStringMap,
isStreaming: Boolean): Option[LogicalPlan] = {
expandRelationName(identifier) match {
expandIdentifier(identifier) match {
case SessionCatalogAndIdentifier(catalog, ident) =>
lazy val loaded = CatalogV2Util.loadTable(catalog, ident).map {
case v1Table: V1Table =>
@ -1569,8 +1573,7 @@ class Analyzer(override val catalogManager: CatalogManager)
// results and confuse users if there is any null values. For count(t1.*, t2.*), it is
// still allowed, since it's well-defined in spark.
if (!conf.allowStarWithSingleTableIdentifierInCount &&
f1.name.database.isEmpty &&
f1.name.funcName == "count" &&
f1.nameParts == Seq("count") &&
f1.arguments.length == 1) {
f1.arguments.foreach {
case u: UnresolvedStar if u.isQualifiedByTable(child, resolver) =>
@ -1958,17 +1961,19 @@ class Analyzer(override val catalogManager: CatalogManager)
override def apply(plan: LogicalPlan): LogicalPlan = {
val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]()
plan.resolveExpressions {
case f: UnresolvedFunction
if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
externalFunctionNameSet.add(normalizeFuncName(f.name))
case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) =>
if (externalFunctionNameSet.contains(normalizeFuncName(ident)) ||
v1SessionCatalog.isRegisteredFunction(ident)) {
f
case f: UnresolvedFunction =>
} else if (v1SessionCatalog.isPersistentFunction(ident)) {
externalFunctionNameSet.add(normalizeFuncName(ident))
f
} else {
withPosition(f) {
throw new NoSuchFunctionException(
f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
f.name.funcName)
ident.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
ident.funcName)
}
}
}
}
@ -2016,9 +2021,10 @@ class Analyzer(override val catalogManager: CatalogManager)
name, other.getClass.getCanonicalName)
}
}
case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter, ignoreNulls) =>
withPosition(u) {
v1SessionCatalog.lookupFunction(funcId, arguments) match {
case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments, isDistinct, filter,
ignoreNulls) => withPosition(u) {
v1SessionCatalog.lookupFunction(ident, arguments) match {
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
@ -2096,6 +2102,120 @@ class Analyzer(override val catalogManager: CatalogManager)
other
}
}
case u @ UnresolvedFunction(nameParts, arguments, isDistinct, filter, ignoreNulls) =>
withPosition(u) {
expandIdentifier(nameParts) match {
case NonSessionCatalogAndIdentifier(catalog, ident) =>
if (!catalog.isFunctionCatalog) {
throw new AnalysisException(s"Trying to lookup function '$ident' in " +
s"catalog '${catalog.name()}', but it is not a FunctionCatalog.")
}
val unbound = catalog.asFunctionCatalog.loadFunction(ident)
val inputType = StructType(arguments.zipWithIndex.map {
case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
})
val bound = try {
unbound.bind(inputType)
} catch {
case unsupported: UnsupportedOperationException =>
throw new AnalysisException(s"Function '${unbound.name}' cannot process " +
s"input: (${arguments.map(_.dataType.simpleString).mkString(", ")}): " +
unsupported.getMessage, cause = Some(unsupported))
}
bound match {
case scalarFunc: ScalarFunction[_] =>
processV2ScalarFunction(scalarFunc, inputType, arguments, isDistinct,
filter, ignoreNulls)
case aggFunc: V2AggregateFunction[_, _] =>
processV2AggregateFunction(aggFunc, arguments, isDistinct, filter,
ignoreNulls)
case _ =>
failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction" +
s" or AggregateFunction")
}
case _ => u
}
}
}
}
private def processV2ScalarFunction(
scalarFunc: ScalarFunction[_],
inputType: StructType,
arguments: Seq[Expression],
isDistinct: Boolean,
filter: Option[Expression],
ignoreNulls: Boolean): Expression = {
if (isDistinct) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "DISTINCT")
} else if (filter.isDefined) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "FILTER clause")
} else if (ignoreNulls) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "IGNORE NULLS")
} else {
// TODO: implement type coercion by looking at input type from the UDF. We
// may also want to check if the parameter types from the magic method
// match the input type through `BoundFunction.inputTypes`.
val argClasses = inputType.fields.map(_.dataType)
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
arguments, returnNullable = scalarFunc.isResultNullable)
case _ =>
// TODO: handle functions defined in Scala too - in Scala, even if a
// subclass does not override the default method in a parent interface
// defined in Java, the method can still be found from
// `getDeclaredMethod`.
// Since `inputType` is a `StructType`, it is mapped to an `InternalRow`
// which we can use to look up the `produceResult` method.
findMethod(scalarFunc, "produceResult", Seq(inputType)) match {
case Some(_) =>
ApplyFunctionExpression(scalarFunc, arguments)
case None =>
failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" +
s" magic method nor override 'produceResult'")
}
}
}
}
private def processV2AggregateFunction(
aggFunc: V2AggregateFunction[_, _],
arguments: Seq[Expression],
isDistinct: Boolean,
filter: Option[Expression],
ignoreNulls: Boolean): Expression = {
if (ignoreNulls) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
aggFunc.name(), "IGNORE NULLS")
}
val aggregator = V2Aggregator(aggFunc, arguments)
AggregateExpression(aggregator, Complete, isDistinct, filter)
}
/**
* Check if the input `fn` implements the given `methodName` with parameter types specified
* via `inputType`.
*/
private def findMethod(
fn: BoundFunction,
methodName: String,
inputType: Seq[DataType]): Option[Method] = {
val cls = fn.getClass
try {
val argClasses = inputType.map(ScalaReflection.dataTypeJavaClass)
Some(cls.getDeclaredMethod(methodName, argClasses: _*))
} catch {
case _: NoSuchMethodException =>
None
}
}
}


@ -17,10 +17,10 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.{CatalogManager, LookupCatalog}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.DataType
@ -30,13 +30,14 @@ import org.apache.spark.sql.types.DataType
* so we need to resolve higher order function when all children are either resolved or a lambda
* function.
*/
case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] {
case class ResolveHigherOrderFunctions(catalogManager: CatalogManager)
extends Rule[LogicalPlan] with LookupCatalog {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
case u @ UnresolvedFunction(fn, children, false, filter, ignoreNulls)
case u @ UnresolvedFunction(AsFunctionIdentifier(ident), children, false, filter, ignoreNulls)
if hasLambdaAndResolvedArguments(children) =>
withPosition(u) {
catalog.lookupFunction(fn, children) match {
catalogManager.v1SessionCatalog.lookupFunction(ident, children) match {
case func: HigherOrderFunction =>
filter.foreach(_.failAnalysis("FILTER predicate specified, " +
s"but ${func.prettyName} is not an aggregate function"))


@ -269,12 +269,13 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio
}
case class UnresolvedFunction(
name: FunctionIdentifier,
nameParts: Seq[String],
arguments: Seq[Expression],
isDistinct: Boolean,
filter: Option[Expression] = None,
ignoreNulls: Boolean = false)
extends Expression with Unevaluable {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
override def children: Seq[Expression] = arguments ++ filter.toSeq
@ -282,10 +283,10 @@ case class UnresolvedFunction(
override def nullable: Boolean = throw new UnresolvedException("nullable")
override lazy val resolved = false
override def prettyName: String = name.unquotedString
override def prettyName: String = nameParts.quoted
override def toString: String = {
val distinct = if (isDistinct) "distinct " else ""
s"'$name($distinct${children.mkString(", ")})"
s"'${nameParts.quoted}($distinct${children.mkString(", ")})"
}
override protected def withNewChildrenInternal(
@ -299,8 +300,17 @@ case class UnresolvedFunction(
}
object UnresolvedFunction {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
def apply(
name: FunctionIdentifier,
arguments: Seq[Expression],
isDistinct: Boolean): UnresolvedFunction = {
UnresolvedFunction(name.asMultipart, arguments, isDistinct)
}
def apply(name: String, arguments: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = {
UnresolvedFunction(FunctionIdentifier(name, None), arguments, isDistinct)
UnresolvedFunction(Seq(name), arguments, isDistinct)
}
}


@ -1576,6 +1576,8 @@ class SessionCatalog(
name: FunctionIdentifier,
children: Seq[Expression],
registry: FunctionRegistryBase[T]): T = synchronized {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
// Note: the implementation of this function is a little bit convoluted.
// We probably shouldn't use a single FunctionRegistry to register all three kinds of functions
// (built-in, temp, and external).
@ -1598,7 +1600,8 @@ class SessionCatalog(
case Seq() => getCurrentDatabase
case Seq(_, db) => db
case Seq(catalog, namespace @ _*) =>
throw QueryCompilationErrors.v2CatalogNotSupportFunctionError(catalog, namespace)
throw new IllegalStateException(s"[BUG] unexpected v2 catalog: $catalog, and " +
s"namespace: ${namespace.quoted} in v1 function lookup")
}
// If the name itself is not qualified, add the current database to it.


@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.DataType
case class ApplyFunctionExpression(
function: ScalarFunction[_],
children: Seq[Expression]) extends Expression with UserDefinedExpression with CodegenFallback {
override def nullable: Boolean = function.isResultNullable
override def name: String = function.name()
override def dataType: DataType = function.resultType()
private lazy val reusedRow = new GenericInternalRow(children.size)
/** Returns the result of evaluating this expression on a given input Row */
override def eval(input: InternalRow): Any = {
children.zipWithIndex.foreach {
case (expr, pos) =>
reusedRow.update(pos, expr.eval(input))
}
function.produceResult(reusedRow)
}
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
}
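
A hedged usage sketch of the evaluation path above (it assumes a `StrLenDefault` function that overrides `produceResult`, in the style of the one defined in `DataSourceV2FunctionSuite` further below): the child expressions are evaluated into the reused row, which is then handed to `produceResult`.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Literal}

// `StrLenDefault` is assumed to be a ScalarFunction[Int] overriding produceResult.
val expr = ApplyFunctionExpression(StrLenDefault, Seq(Literal("abc")))
assert(expr.eval(InternalRow.empty) == 3)
```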


@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction}
import org.apache.spark.sql.types.DataType
case class V2Aggregator[BUF <: java.io.Serializable, OUT](
aggrFunc: V2AggregateFunction[BUF, OUT],
children: Seq[Expression],
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[BUF] {
private[this] lazy val inputProjection = UnsafeProjection.create(children)
override def nullable: Boolean = aggrFunc.isResultNullable
override def dataType: DataType = aggrFunc.resultType()
override def createAggregationBuffer(): BUF = aggrFunc.newAggregationState()
override def update(buffer: BUF, input: InternalRow): BUF = {
aggrFunc.update(buffer, inputProjection(input))
}
override def merge(buffer: BUF, input: BUF): BUF = aggrFunc.merge(buffer, input)
override def eval(buffer: BUF): Any = {
aggrFunc.produceResult(buffer)
}
override def serialize(buffer: BUF): Array[Byte] = {
val bos = new ByteArrayOutputStream()
val out = new ObjectOutputStream(bos)
out.writeObject(buffer)
out.close()
bos.toByteArray
}
override def deserialize(bytes: Array[Byte]): BUF = {
val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
in.readObject().asInstanceOf[BUF]
}
def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): V2Aggregator[BUF, OUT] =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): V2Aggregator[BUF, OUT] =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
}
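
The intermediate aggregation state travels between partial and final aggregation using plain Java serialization; a minimal standalone sketch of the round trip performed by `serialize`/`deserialize` above (a tuple state is used here only as a stand-in):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

val state: (Long, Long) = (42L, 7L)  // e.g. a (sum, count) buffer
val bos = new ByteArrayOutputStream()
val out = new ObjectOutputStream(bos)
out.writeObject(state)               // what serialize() does with the buffer
out.close()
val in = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
val restored = in.readObject().asInstanceOf[(Long, Long)]  // what deserialize() does
assert(restored == state)
```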


@ -1816,7 +1816,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
val ignoreNulls =
Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE).getOrElse(false)
val function = UnresolvedFunction(
getFunctionIdentifier(ctx.functionName), arguments, isDistinct, filter, ignoreNulls)
getFunctionMultiparts(ctx.functionName), arguments, isDistinct, filter, ignoreNulls)
// Check if the function is evaluated in a windowed context.
ctx.windowSpec match {
@ -1828,15 +1828,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
}
}
/**
* Create a function database (optional) and name pair, for multipartIdentifier.
* This is used in CREATE FUNCTION, DROP FUNCTION, SHOWFUNCTIONS.
*/
protected def visitFunctionName(ctx: MultipartIdentifierContext): FunctionIdentifier = {
visitFunctionName(ctx, ctx.parts.asScala.map(_.getText).toSeq)
}
/**
* Create a function database (optional) and name pair.
*/
@ -1867,6 +1858,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
}
}
protected def getFunctionMultiparts(ctx: FunctionNameContext): Seq[String] = {
if (ctx.qualifiedName != null) {
ctx.qualifiedName().identifier().asScala.map(_.getText).toSeq
} else {
Seq(ctx.getText)
}
}
/**
* Create an [[LambdaFunction]].
*/


@ -83,12 +83,36 @@ private[sql] object CatalogV2Implicits {
throw new AnalysisException(
s"Cannot use catalog ${plugin.name}: does not support namespaces")
}
def isFunctionCatalog: Boolean = plugin match {
case _: FunctionCatalog => true
case _ => false
}
def asFunctionCatalog: FunctionCatalog = plugin match {
case functionCatalog: FunctionCatalog =>
functionCatalog
case _ =>
throw new AnalysisException(
s"Cannot use catalog '${plugin.name}': not a FunctionCatalog")
}
}
implicit class NamespaceHelper(namespace: Array[String]) {
def quoted: String = namespace.map(quoteIfNeeded).mkString(".")
}
implicit class FunctionIdentifierHelper(ident: FunctionIdentifier) {
def asMultipart: Seq[String] = {
ident.database match {
case Some(db) =>
Seq(db, ident.funcName)
case _ =>
Seq(ident.funcName)
}
}
}
implicit class IdentifierHelper(ident: Identifier) {
def quoted: String = {
if (ident.namespace.nonEmpty) {
@ -132,6 +156,14 @@ private[sql] object CatalogV2Implicits {
s"$quoted is not a valid TableIdentifier as it has more than 2 name parts.")
}
def asFunctionIdentifier: FunctionIdentifier = parts match {
case Seq(funcName) => FunctionIdentifier(funcName)
case Seq(dbName, funcName) => FunctionIdentifier(funcName, Some(dbName))
case _ =>
throw new AnalysisException(
s"$quoted is not a valid FunctionIdentifier as it has more than 2 name parts.")
}
def quoted: String = parts.map(quoteIfNeeded).mkString(".")
}


@ -154,6 +154,28 @@ private[sql] trait LookupCatalog extends Logging {
}
}
object AsFunctionIdentifier {
def unapply(parts: Seq[String]): Option[FunctionIdentifier] = {
def namesToFunctionIdentifier(names: Seq[String]): Option[FunctionIdentifier] = names match {
case Seq(name) => Some(FunctionIdentifier(name))
case Seq(database, name) => Some(FunctionIdentifier(name, Some(database)))
case _ => None
}
parts match {
case Seq(name)
if catalogManager.v1SessionCatalog.isRegisteredFunction(FunctionIdentifier(name)) =>
Some(FunctionIdentifier(name))
case CatalogAndMultipartIdentifier(None, names)
if CatalogV2Util.isSessionCatalog(currentCatalog) =>
namesToFunctionIdentifier(names)
case CatalogAndMultipartIdentifier(Some(catalog), names)
if CatalogV2Util.isSessionCatalog(catalog) =>
namesToFunctionIdentifier(names)
case _ => None
}
}
}
def parseSessionCatalogFunctionIdentifier(nameParts: Seq[String]): FunctionIdentifier = {
if (nameParts.length == 1 && catalogManager.v1SessionCatalog.isTempFunction(nameParts.head)) {
return FunctionIdentifier(nameParts.head)


@ -617,12 +617,6 @@ private[spark] object QueryCompilationErrors {
s"the function '$func', please make sure it is on the classpath")
}
def v2CatalogNotSupportFunctionError(
catalog: String, namespace: Seq[String]): Throwable = {
new AnalysisException("V2 catalog does not support functions yet. " +
s"catalog: $catalog, namespace: '${namespace.quoted}'")
}
def resourceTypeNotSupportedError(resourceType: String): Throwable = {
new AnalysisException(s"Resource Type '$resourceType' is not supported.")
}


@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog,
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
class LookupFunctionsSuite extends PlanTest {
@ -49,7 +50,7 @@ class LookupFunctionsSuite extends PlanTest {
assert(externalCatalog.getFunctionExistsCalledTimes == 1)
assert(analyzer.LookupFunctions.normalizeFuncName
(unresolvedPersistentFunc.name).database == Some("default"))
(unresolvedPersistentFunc.nameParts.asFunctionIdentifier).database == Some("default"))
}
test("SPARK-23486: the functionExists for the Registered function check") {
@ -70,9 +71,9 @@ class LookupFunctionsSuite extends PlanTest {
table("TaBlE"))
analyzer.LookupFunctions.apply(plan)
assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2)
assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 4)
assert(analyzer.LookupFunctions.normalizeFuncName
(unresolvedRegisteredFunc.name).database == Some("default"))
(unresolvedRegisteredFunc.nameParts.asFunctionIdentifier).database == Some("default"))
}
}


@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, FakeV2SessionCatalog, NoSuchNamespaceException}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog => V1InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
class CatalogManagerSuite extends SparkFunSuite with SQLHelper {
private def createSessionCatalog(): SessionCatalog = {
val catalog = new InMemoryCatalog()
val catalog = new V1InMemoryCatalog()
catalog.createDatabase(
CatalogDatabase(SessionCatalog.DEFAULT_DATABASE, "", new URI("fake"), Map.empty),
ignoreIfExists = true)


@ -24,14 +24,15 @@ import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.expressions.LogicalExpressions
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
class TableCatalogSuite extends SparkFunSuite {
class CatalogSuite extends SparkFunSuite {
import CatalogV2Implicits._
private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String]
@ -39,8 +40,8 @@ class TableCatalogSuite extends SparkFunSuite {
.add("id", IntegerType)
.add("data", StringType)
private def newCatalog(): TableCatalog with SupportsNamespaces = {
val newCatalog = new InMemoryTableCatalog
private def newCatalog(): InMemoryCatalog = {
val newCatalog = new InMemoryCatalog
newCatalog.initialize("test", CaseInsensitiveStringMap.empty())
newCatalog
}
@ -927,4 +928,43 @@ class TableCatalogSuite extends SparkFunSuite {
assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 2)
assert(partTable.rows.isEmpty)
}
val function: UnboundFunction = new UnboundFunction {
override def bind(inputType: StructType): BoundFunction = new ScalarFunction[Int] {
override def inputTypes(): Array[DataType] = Array(IntegerType)
override def resultType(): DataType = IntegerType
override def name(): String = "my_bound_function"
}
override def description(): String = "my_function"
override def name(): String = "my_function"
}
test("list functions") {
val catalog = newCatalog()
val ident1 = Identifier.of(Array("ns1", "ns2"), "func1")
val ident2 = Identifier.of(Array("ns1", "ns2"), "func2")
val ident3 = Identifier.of(Array("ns1", "ns3"), "func3")
catalog.createNamespace(Array("ns1", "ns2"), emptyProps)
catalog.createNamespace(Array("ns1", "ns3"), emptyProps)
catalog.createFunction(ident1, function)
catalog.createFunction(ident2, function)
catalog.createFunction(ident3, function)
assert(catalog.listFunctions(Array("ns1", "ns2")).toSet === Set(ident1, ident2))
assert(catalog.listFunctions(Array("ns1", "ns3")).toSet === Set(ident3))
assert(catalog.listFunctions(Array("ns1")).toSet == Set())
intercept[NoSuchNamespaceException](catalog.listFunctions(Array("ns2")))
}
test("lookup function") {
val catalog = newCatalog()
val ident = Identifier.of(Array("ns"), "func")
catalog.createNamespace(Array("ns"), emptyProps)
catalog.createFunction(ident, function)
assert(catalog.loadFunction(ident) == function)
intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns"), "func1")))
intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns1"), "func")))
}
}


@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.catalog
import java.util
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, NoSuchNamespaceException}
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog {
protected val functions: util.Map[Identifier, UnboundFunction] =
new ConcurrentHashMap[Identifier, UnboundFunction]()
override protected def allNamespaces: Seq[Seq[String]] = {
(tables.keySet.asScala.map(_.namespace.toSeq) ++
functions.keySet.asScala.map(_.namespace.toSeq) ++
namespaces.keySet.asScala).toSeq.distinct
}
override def listFunctions(namespace: Array[String]): Array[Identifier] = {
if (namespace.isEmpty || namespaceExists(namespace)) {
functions.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
} else {
throw new NoSuchNamespaceException(namespace)
}
}
override def loadFunction(ident: Identifier): UnboundFunction = {
Option(functions.get(ident)) match {
case Some(func) =>
func
case _ =>
throw new NoSuchFunctionException(ident)
}
}
def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction = {
functions.put(ident, fn)
}
}


@ -138,7 +138,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
}
class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamespaces {
private def allNamespaces: Seq[Seq[String]] = {
protected def allNamespaces: Seq[Seq[String]] = {
(tables.keySet.asScala.map(_.namespace.toSeq) ++ namespaces.keySet.asScala).toSeq.distinct
}


@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test.org.apache.spark.sql.connector.catalog.functions;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.AggregateFunction;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.StructType;
import java.io.Serializable;
public class JavaAverage implements UnboundFunction {
@Override
public String name() {
return "avg";
}
@Override
public BoundFunction bind(StructType inputType) {
if (inputType.fields().length != 1) {
throw new UnsupportedOperationException("Expect exactly one argument");
}
if (inputType.fields()[0].dataType() instanceof DoubleType) {
return new JavaDoubleAverage();
}
throw new UnsupportedOperationException("Unsupported non-integral type: " +
inputType.fields()[0].dataType());
}
@Override
public String description() {
return null;
}
public static class JavaDoubleAverage implements AggregateFunction<State<Double>, Double> {
@Override
public State<Double> newAggregationState() {
return new State<>(0.0, 0.0);
}
@Override
public State<Double> update(State<Double> state, InternalRow input) {
if (input.isNullAt(0)) {
return state;
}
return new State<>(state.sum + input.getDouble(0), state.count + 1);
}
@Override
public Double produceResult(State<Double> state) {
return state.sum / state.count;
}
@Override
public State<Double> merge(State<Double> leftState, State<Double> rightState) {
return new State<>(leftState.sum + rightState.sum, leftState.count + rightState.count);
}
@Override
public DataType[] inputTypes() {
return new DataType[] { DataTypes.DoubleType };
}
@Override
public DataType resultType() {
return DataTypes.DoubleType;
}
@Override
public String name() {
return "davg";
}
}
public static class State<T> implements Serializable {
T sum, count;
State(T sum, T count) {
this.sum = sum;
this.count = count;
}
}
}


@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test.org.apache.spark.sql.connector.catalog.functions;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
public class JavaStrLen implements UnboundFunction {
private final BoundFunction fn;
public JavaStrLen(BoundFunction fn) {
this.fn = fn;
}
@Override
public String name() {
return "strlen";
}
@Override
public BoundFunction bind(StructType inputType) {
if (inputType.fields().length != 1) {
throw new UnsupportedOperationException("Expect exactly one argument");
}
if (inputType.fields()[0].dataType() instanceof StringType) {
return fn;
}
throw new UnsupportedOperationException("Expect StringType");
}
@Override
public String description() {
return "strlen: returns the length of the input string\n" +
" strlen(string) -> int";
}
public static class JavaStrLenDefault implements ScalarFunction<Integer> {
@Override
public DataType[] inputTypes() {
return new DataType[] { DataTypes.StringType };
}
@Override
public DataType resultType() {
return DataTypes.IntegerType;
}
@Override
public String name() {
return "strlen";
}
@Override
public Integer produceResult(InternalRow input) {
String str = input.getString(0);
return str.length();
}
}
public static class JavaStrLenMagic implements ScalarFunction<Integer> {
@Override
public DataType[] inputTypes() {
return new DataType[] { DataTypes.StringType };
}
@Override
public DataType resultType() {
return DataTypes.IntegerType;
}
@Override
public String name() {
return "strlen";
}
public int invoke(UTF8String str) {
return str.toString().length();
}
}
public static class JavaStrLenNoImpl implements ScalarFunction<Integer> {
@Override
public DataType[] inputTypes() {
return new DataType[] { DataTypes.StringType };
}
@Override
public DataType resultType() {
return DataTypes.IntegerType;
}
@Override
public String name() {
return "strlen";
}
}
}


@ -0,0 +1,420 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector
import java.util
import java.util.Collections
import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaStrLen}
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String]
private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = {
catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn)
}
test("undefined function") {
assert(intercept[AnalysisException](
sql("SELECT testcat.non_exist('abc')").collect()
).getMessage.contains("Undefined function"))
}
test("non-function catalog") {
withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) {
assert(intercept[AnalysisException](
sql("SELECT testcat.strlen('abc')").collect()
).getMessage.contains("is not a FunctionCatalog"))
}
}
test("built-in with non-function catalog should still work") {
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat",
"spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) {
checkAnswer(sql("SELECT length('abc')"), Row(3))
}
}
test("built-in with default v2 function catalog") {
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
checkAnswer(sql("SELECT length('abc')"), Row(3))
}
}
test("looking up higher-order function with non-session catalog") {
checkAnswer(sql("SELECT transform(array(1, 2, 3), x -> x + 1)"),
Row(Array(2, 3, 4)) :: Nil)
}
test("built-in override with default v2 function catalog") {
// a built-in function with the same name should take higher priority
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl))
checkAnswer(sql("SELECT length('abc')"), Row(3))
}
}
test("built-in override with non-session catalog") {
addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl))
checkAnswer(sql("SELECT length('abc')"), Row(3))
}
test("temp function override with default v2 function catalog") {
val className = "test.org.apache.spark.sql.JavaStringLength"
sql(s"CREATE FUNCTION length AS '$className'")
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl))
checkAnswer(sql("SELECT length('abc')"), Row(3))
}
}
test("view should use captured catalog and namespace for function lookup") {
val viewName = "my_view"
withView(viewName) {
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "my_avg"), IntegralAverage)
sql("USE ns")
sql(s"CREATE TEMPORARY VIEW $viewName AS SELECT my_avg(col1) FROM values (1), (2), (3)")
}
// change default catalog and namespace and add a function with the same name but with no
// implementation
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat2") {
catalog("testcat2").asInstanceOf[SupportsNamespaces]
.createNamespace(Array("ns2"), emptyProps)
addFunction(Identifier.of(Array("ns2"), "my_avg"), NoImplAverage)
sql("USE ns2")
checkAnswer(sql(s"SELECT * FROM $viewName"), Row(2.0) :: Nil)
}
}
}
test("scalar function: with default produceResult method") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
}
test("scalar function: with default produceResult method w/ expression") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
checkAnswer(sql("SELECT testcat.ns.strlen(substr('abcde', 3))"), Row(3) :: Nil)
}
test("scalar function: lookup magic method") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic))
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
}
test("scalar function: lookup magic method w/ expression") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic))
checkAnswer(sql("SELECT testcat.ns.strlen(substr('abcde', 3))"), Row(3) :: Nil)
}
test("scalar function: bad magic method") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagic))
assert(intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect())
.getMessage.contains("Cannot find a compatible"))
}
test("scalar function: bad magic method with default impl") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagicWithDefault))
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
}
test("scalar function: no implementation found") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenNoImpl))
intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect())
}
test("scalar function: invalid parameter type or length") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen(42)"))
.getMessage.contains("Expect StringType"))
assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('a', 'b')"))
.getMessage.contains("Expect exactly one argument"))
}
test("scalar function: default produceResult in Java") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"),
new JavaStrLen(new JavaStrLenDefault))
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
}
test("scalar function: magic method in Java") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"),
new JavaStrLen(new JavaStrLenMagic))
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
}
test("scalar function: no implementation found in Java") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"),
new JavaStrLen(new JavaStrLenNoImpl))
assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')").collect())
.getMessage.contains("neither implement magic method nor override 'produceResult'"))
}
test("bad bound function (neither scalar nor aggregate)") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(BadBoundFunction))
assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')"))
.getMessage.contains("does not implement ScalarFunction or AggregateFunction"))
}
test("aggregate function: lookup int average") {
import testImplicits._
val t = "testcat.ns.t"
withTable(t) {
addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
(1 to 100).toDF("i").write.saveAsTable(t)
checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(50) :: Nil)
}
}
test("aggregate function: lookup long average") {
import testImplicits._
val t = "testcat.ns.t"
withTable(t) {
addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
(1L to 100L).toDF("i").write.saveAsTable(t)
checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(50) :: Nil)
}
}
test("aggregate function: lookup double average in Java") {
import testImplicits._
val t = "testcat.ns.t"
withTable(t) {
addFunction(Identifier.of(Array("ns"), "avg"), new JavaAverage)
Seq(1.toDouble, 2.toDouble, 3.toDouble).toDF("i").write.saveAsTable(t)
checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(2.0) :: Nil)
}
}
test("aggregate function: lookup int average w/ expression") {
import testImplicits._
val t = "testcat.ns.t"
withTable(t) {
addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
(1 to 100).toDF("i").write.saveAsTable(t)
checkAnswer(sql(s"SELECT testcat.ns.avg(i * 10) from $t"), Row(505) :: Nil)
}
}
test("aggregate function: unsupported input type") {
import testImplicits._
val t = "testcat.ns.t"
withTable(t) {
addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
Seq(1.toShort, 2.toShort).toDF("i").write.saveAsTable(t)
assert(intercept[AnalysisException](sql(s"SELECT testcat.ns.avg(i) from $t"))
.getMessage.contains("Unsupported non-integral type: ShortType"))
}
}
private case class StrLen(impl: BoundFunction) extends UnboundFunction {
override def description(): String =
"""strlen: returns the length of the input string
| strlen(string) -> int""".stripMargin
override def name(): String = "strlen"
override def bind(inputType: StructType): BoundFunction = {
if (inputType.fields.length != 1) {
throw new UnsupportedOperationException("Expect exactly one argument")
}
inputType.fields(0).dataType match {
case StringType => impl
case _ =>
throw new UnsupportedOperationException("Expect StringType")
}
}
}
private case object StrLenDefault extends ScalarFunction[Int] {
override def inputTypes(): Array[DataType] = Array(StringType)
override def resultType(): DataType = IntegerType
override def name(): String = "strlen_default"
override def produceResult(input: InternalRow): Int = {
val s = input.getString(0)
s.length
}
}
private case object StrLenMagic extends ScalarFunction[Int] {
override def inputTypes(): Array[DataType] = Array(StringType)
override def resultType(): DataType = IntegerType
override def name(): String = "strlen_magic"
def invoke(input: UTF8String): Int = {
input.toString.length
}
}
private case object StrLenBadMagic extends ScalarFunction[Int] {
override def inputTypes(): Array[DataType] = Array(StringType)
override def resultType(): DataType = IntegerType
override def name(): String = "strlen_bad_magic"
def invoke(input: String): Int = {
input.length
}
}
private case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
override def inputTypes(): Array[DataType] = Array(StringType)
override def resultType(): DataType = IntegerType
override def name(): String = "strlen_bad_magic"
def invoke(input: String): Int = {
input.length
}
override def produceResult(input: InternalRow): Int = {
val s = input.getString(0)
s.length
}
}
private case object StrLenNoImpl extends ScalarFunction[Int] {
override def inputTypes(): Array[DataType] = Array(StringType)
override def resultType(): DataType = IntegerType
override def name(): String = "strlen_noimpl"
}
private case object BadBoundFunction extends BoundFunction {
override def inputTypes(): Array[DataType] = Array(StringType)
override def resultType(): DataType = IntegerType
override def name(): String = "bad_bound_func"
}
object IntegralAverage extends UnboundFunction {
override def name(): String = "iavg"
override def bind(inputType: StructType): BoundFunction = {
if (inputType.fields.length > 1) {
throw new UnsupportedOperationException("Too many arguments")
}
inputType.fields(0).dataType match {
case _: IntegerType => IntAverage
case _: LongType => LongAverage
case dataType =>
throw new UnsupportedOperationException(s"Unsupported non-integral type: $dataType")
}
}
override def description(): String =
"""iavg: produces an average using integer division, ignoring nulls
| iavg(int) -> int
| iavg(bigint) -> bigint""".stripMargin
}
object IntAverage extends AggregateFunction[(Int, Int), Int] {
override def name(): String = "iavg"
override def inputTypes(): Array[DataType] = Array(IntegerType)
override def resultType(): DataType = IntegerType
override def newAggregationState(): (Int, Int) = (0, 0)
override def update(state: (Int, Int), input: InternalRow): (Int, Int) = {
if (input.isNullAt(0)) {
state
} else {
val i = input.getInt(0)
state match {
case (_, 0) =>
(i, 1)
case (total, count) =>
(total + i, count + 1)
}
}
}
override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = {
(leftState._1 + rightState._1, leftState._2 + rightState._2)
}
override def produceResult(state: (Int, Int)): Int = state._1 / state._2
}
object LongAverage extends AggregateFunction[(Long, Long), Long] {
override def name(): String = "iavg"
override def inputTypes(): Array[DataType] = Array(LongType)
override def resultType(): DataType = LongType
override def newAggregationState(): (Long, Long) = (0L, 0L)
override def update(state: (Long, Long), input: InternalRow): (Long, Long) = {
if (input.isNullAt(0)) {
state
} else {
val l = input.getLong(0)
state match {
case (_, 0L) =>
(l, 1)
case (total, count) =>
(total + l, count + 1L)
}
}
}
override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = {
(leftState._1 + rightState._1, leftState._2 + rightState._2)
}
override def produceResult(state: (Long, Long)): Long = state._1 / state._2
}
object NoImplAverage extends UnboundFunction {
override def name(): String = "no_impl_avg"
override def description(): String = name()
override def bind(inputType: StructType): BoundFunction = {
throw new UnsupportedOperationException(s"Not implemented")
}
}
}


@ -20,7 +20,7 @@ package org.apache.spark.sql.connector
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryPartitionTableCatalog, InMemoryTableCatalog, StagingInMemoryTableCatalog}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryCatalog, InMemoryPartitionTableCatalog, StagingInMemoryTableCatalog}
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.SharedSparkSession
@ -32,11 +32,11 @@ trait DatasourceV2SQLBase
}
before {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
spark.conf.set("spark.sql.catalog.testpart", classOf[InMemoryPartitionTableCatalog].getName)
spark.conf.set(
"spark.sql.catalog.testcat_atomic", classOf[StagingInMemoryTableCatalog].getName)
spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName)
spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryCatalog].getName)
spark.conf.set(
V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName)