[SPARK-34981][SQL] Implement V2 function resolution and evaluation
Co-Authored-By: Chao Sun <sunchaoapple.com> Co-Authored-By: Ryan Blue <rbluenetflix.com> ### What changes were proposed in this pull request? This implements function resolution and evaluation for functions registered through V2 FunctionCatalog [SPARK-27658](https://issues.apache.org/jira/browse/SPARK-27658). In particular: - Added documentation for how to define the "magic method" in `ScalarFunction`. - Added a new expression `ApplyFunctionExpression` which evaluates input by delegating to `ScalarFunction.produceResult` method. - added a new expression `V2Aggregator` which is a type of `TypedImperativeAggregate`. It's a wrapper of V2 `AggregateFunction` and mostly delegate methods to the implementation of the latter. It also uses plain Java serde for intermediate state. - Added function resolution logic for `ScalarFunction` and `AggregateFunction` in `Analyzer`. + For `ScalarFunction` this checks if the magic method is implemented through Java reflection, and create a `Invoke` expression if so. Otherwise, it checks if the default `produceResult` is overridden. If so, it creates a `ApplyFunctionExpression` which evaluates through `InternalRow`. Otherwise an analysis exception is thrown. + For `AggregateFunction`, this checks if the `update` method is overridden. If so, it converts it to `V2Aggregator`. Otherwise an analysis exception is thrown similar to the case of `ScalarFunction`. - Extended existing `InMemoryTableCatalog` to add the function catalog capability. Also renamed it to `InMemoryCatalog` since it no longer only covers tables. **Note**: this currently can successfully detect whether a subclass overrides the default `produceResult` or `update` method from the parent interface **only for Java implementations**. It seems in Scala it's hard to differentiate whether a subclass overrides a default method from its parent interface. In this case, it will be a runtime error instead of analysis error. A few TODOs: - Extend `V2SessionCatalog` with function catalog. This seems a little tricky since API such V2 `FunctionCatalog`'s `loadFunction` is different from V1 `SessionCatalog`'s `lookupFunction`. - Add magic method for `AggregateFunction`. - Type coercion when looking up functions ### Why are the changes needed? As V2 FunctionCatalog APIs are finalized, we should integrate it with function resolution and evaluation process so that they are actually useful. ### Does this PR introduce _any_ user-facing change? Yes, now a function exposed through V2 FunctionCatalog can be analyzed and evaluated. ### How was this patch tested? Added new unit tests. Closes #32082 from sunchao/resolve-func-v2. Lead-authored-by: Chao Sun <sunchao@apple.com> Co-authored-by: Chao Sun <sunchao@apache.org> Co-authored-by: Chao Sun <sunchao@uber.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
0bcf348438
commit
86d3bb5f7d
|
@ -25,13 +25,12 @@ import java.io.Serializable;
|
|||
/**
|
||||
* Interface for a function that produces a result value by aggregating over multiple input rows.
|
||||
* <p>
|
||||
* For each input row, Spark will call an update method that corresponds to the
|
||||
* {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
|
||||
* Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
|
||||
* update with {@link InternalRow}.
|
||||
* <p>
|
||||
* The JVM type of result values produced by this function must be the type used by Spark's
|
||||
* For each input row, Spark will call the {@link #update} method which should evaluate the row
|
||||
* and update the aggregation state. The JVM type of result values produced by
|
||||
* {@link #produceResult} must be the type used by Spark's
|
||||
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
|
||||
* Please refer to class documentation of {@link ScalarFunction} for the mapping between
|
||||
* {@link DataType} and the JVM type.
|
||||
* <p>
|
||||
* All implementations must support partial aggregation by implementing merge so that Spark can
|
||||
* partially aggregate and shuffle intermediate results, instead of shuffling all rows for an
|
||||
|
@ -68,9 +67,7 @@ public interface AggregateFunction<S extends Serializable, R> extends BoundFunct
|
|||
* @param input an input row
|
||||
* @return updated aggregation state
|
||||
*/
|
||||
default S update(S state, InternalRow input) {
|
||||
throw new UnsupportedOperationException("Cannot find a compatible AggregateFunction#update");
|
||||
}
|
||||
S update(S state, InternalRow input);
|
||||
|
||||
/**
|
||||
* Merge two partial aggregation states.
|
||||
|
|
|
@ -23,17 +23,67 @@ import org.apache.spark.sql.types.DataType;
|
|||
/**
|
||||
* Interface for a function that produces a result value for each input row.
|
||||
* <p>
|
||||
* For each input row, Spark will call a produceResult method that corresponds to the
|
||||
* {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
|
||||
* Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
|
||||
* {@link #produceResult(InternalRow)}.
|
||||
* To evaluate each input row, Spark will first try to lookup and use a "magic method" (described
|
||||
* below) through Java reflection. If the method is not found, Spark will call
|
||||
* {@link #produceResult(InternalRow)} as a fallback approach.
|
||||
* <p>
|
||||
* The JVM type of result values produced by this function must be the type used by Spark's
|
||||
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
|
||||
* <p>
|
||||
* <b>IMPORTANT</b>: the default implementation of {@link #produceResult} throws
|
||||
* {@link UnsupportedOperationException}. Users can choose to override this method, or implement
|
||||
* a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters
|
||||
* instead of a {@link InternalRow}. The magic method will be loaded by Spark through Java
|
||||
* reflection and will also provide better performance in general, due to optimizations such as
|
||||
* codegen, removal of Java boxing, etc.
|
||||
*
|
||||
* For example, a scalar UDF for adding two integers can be defined as follow with the magic
|
||||
* method approach:
|
||||
*
|
||||
* <pre>
|
||||
* public class IntegerAdd implements{@code ScalarFunction<Integer>} {
|
||||
* public int invoke(int left, int right) {
|
||||
* return left + right;
|
||||
* }
|
||||
* }
|
||||
* </pre>
|
||||
* In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over
|
||||
* {@link #produceResult} to evalaute the inputs. In general Spark looks up the magic method by
|
||||
* first converting the actual input SQL data types to their corresponding Java types following
|
||||
* the mapping defined below, and then checking if there is a matching method from all the
|
||||
* declared methods in the UDF class, using method name (i.e., {@link #MAGIC_METHOD_NAME}) and
|
||||
* the Java types. If no magic method is found, Spark will falls back to use {@link #produceResult}.
|
||||
* <p>
|
||||
* The following are the mapping from {@link DataType SQL data type} to Java type through
|
||||
* the magic method approach:
|
||||
* <ul>
|
||||
* <li>{@link org.apache.spark.sql.types.BooleanType}: {@code boolean}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.ByteType}: {@code byte}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.ShortType}: {@code short}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.IntegerType}: {@code int}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.LongType}: {@code long}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.FloatType}: {@code float}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.DoubleType}: {@code double}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.StringType}:
|
||||
* {@link org.apache.spark.unsafe.types.UTF8String}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.DateType}: {@code int}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.TimestampType}: {@code long}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.DayTimeIntervalType}: {@code long}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.YearMonthIntervalType}: {@code int}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.DecimalType}:
|
||||
* {@link org.apache.spark.sql.types.Decimal}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.StructType}: {@link InternalRow}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.ArrayType}:
|
||||
* {@link org.apache.spark.sql.catalyst.util.ArrayData}</li>
|
||||
* <li>{@link org.apache.spark.sql.types.MapType}:
|
||||
* {@link org.apache.spark.sql.catalyst.util.MapData}</li>
|
||||
* </ul>
|
||||
*
|
||||
* @param <R> the JVM type of result values
|
||||
*/
|
||||
public interface ScalarFunction<R> extends BoundFunction {
|
||||
String MAGIC_METHOD_NAME = "invoke";
|
||||
|
||||
/**
|
||||
* Applies the function to an input row to produce a value.
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.analysis
|
||||
|
||||
import java.lang.reflect.Method
|
||||
import java.util
|
||||
import java.util.Locale
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
@ -29,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
|
|||
import org.apache.spark.sql.catalyst._
|
||||
import org.apache.spark.sql.catalyst.catalog._
|
||||
import org.apache.spark.sql.catalyst.encoders.OuterScopes
|
||||
import org.apache.spark.sql.catalyst.expressions.{FrameLessOffsetWindowFunction, _}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _}
|
||||
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
||||
import org.apache.spark.sql.catalyst.expressions.objects._
|
||||
|
@ -44,6 +45,8 @@ import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
|
|||
import org.apache.spark.sql.connector.catalog._
|
||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
|
||||
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction}
|
||||
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
|
||||
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
|
||||
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
|
||||
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
|
||||
|
@ -281,7 +284,7 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
ResolveAggregateFunctions ::
|
||||
TimeWindowing ::
|
||||
ResolveInlineTables ::
|
||||
ResolveHigherOrderFunctions(v1SessionCatalog) ::
|
||||
ResolveHigherOrderFunctions(catalogManager) ::
|
||||
ResolveLambdaVariables ::
|
||||
ResolveTimeZone ::
|
||||
ResolveRandomSeed ::
|
||||
|
@ -895,9 +898,10 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
}
|
||||
}
|
||||
|
||||
// If we are resolving relations insides views, we need to expand single-part relation names with
|
||||
// the current catalog and namespace of when the view was created.
|
||||
private def expandRelationName(nameParts: Seq[String]): Seq[String] = {
|
||||
// If we are resolving database objects (relations, functions, etc.) insides views, we may need to
|
||||
// expand single or multi-part identifiers with the current catalog and namespace of when the
|
||||
// view was created.
|
||||
private def expandIdentifier(nameParts: Seq[String]): Seq[String] = {
|
||||
if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts
|
||||
|
||||
if (nameParts.length == 1) {
|
||||
|
@ -1040,7 +1044,7 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
identifier: Seq[String],
|
||||
options: CaseInsensitiveStringMap,
|
||||
isStreaming: Boolean): Option[LogicalPlan] =
|
||||
expandRelationName(identifier) match {
|
||||
expandIdentifier(identifier) match {
|
||||
case NonSessionCatalogAndIdentifier(catalog, ident) =>
|
||||
CatalogV2Util.loadTable(catalog, ident) match {
|
||||
case Some(table) =>
|
||||
|
@ -1153,7 +1157,7 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
}
|
||||
|
||||
private def lookupTableOrView(identifier: Seq[String]): Option[LogicalPlan] = {
|
||||
expandRelationName(identifier) match {
|
||||
expandIdentifier(identifier) match {
|
||||
case SessionCatalogAndIdentifier(catalog, ident) =>
|
||||
CatalogV2Util.loadTable(catalog, ident).map {
|
||||
case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW =>
|
||||
|
@ -1173,7 +1177,7 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
identifier: Seq[String],
|
||||
options: CaseInsensitiveStringMap,
|
||||
isStreaming: Boolean): Option[LogicalPlan] = {
|
||||
expandRelationName(identifier) match {
|
||||
expandIdentifier(identifier) match {
|
||||
case SessionCatalogAndIdentifier(catalog, ident) =>
|
||||
lazy val loaded = CatalogV2Util.loadTable(catalog, ident).map {
|
||||
case v1Table: V1Table =>
|
||||
|
@ -1569,8 +1573,7 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
// results and confuse users if there is any null values. For count(t1.*, t2.*), it is
|
||||
// still allowed, since it's well-defined in spark.
|
||||
if (!conf.allowStarWithSingleTableIdentifierInCount &&
|
||||
f1.name.database.isEmpty &&
|
||||
f1.name.funcName == "count" &&
|
||||
f1.nameParts == Seq("count") &&
|
||||
f1.arguments.length == 1) {
|
||||
f1.arguments.foreach {
|
||||
case u: UnresolvedStar if u.isQualifiedByTable(child, resolver) =>
|
||||
|
@ -1958,17 +1961,19 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
override def apply(plan: LogicalPlan): LogicalPlan = {
|
||||
val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]()
|
||||
plan.resolveExpressions {
|
||||
case f: UnresolvedFunction
|
||||
if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
|
||||
case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
|
||||
case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
|
||||
externalFunctionNameSet.add(normalizeFuncName(f.name))
|
||||
case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) =>
|
||||
if (externalFunctionNameSet.contains(normalizeFuncName(ident)) ||
|
||||
v1SessionCatalog.isRegisteredFunction(ident)) {
|
||||
f
|
||||
case f: UnresolvedFunction =>
|
||||
} else if (v1SessionCatalog.isPersistentFunction(ident)) {
|
||||
externalFunctionNameSet.add(normalizeFuncName(ident))
|
||||
f
|
||||
} else {
|
||||
withPosition(f) {
|
||||
throw new NoSuchFunctionException(
|
||||
f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
|
||||
f.name.funcName)
|
||||
ident.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
|
||||
ident.funcName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2016,9 +2021,10 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
name, other.getClass.getCanonicalName)
|
||||
}
|
||||
}
|
||||
case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter, ignoreNulls) =>
|
||||
withPosition(u) {
|
||||
v1SessionCatalog.lookupFunction(funcId, arguments) match {
|
||||
|
||||
case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments, isDistinct, filter,
|
||||
ignoreNulls) => withPosition(u) {
|
||||
v1SessionCatalog.lookupFunction(ident, arguments) match {
|
||||
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
|
||||
// the context of a Window clause. They do not need to be wrapped in an
|
||||
// AggregateExpression.
|
||||
|
@ -2096,6 +2102,120 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
other
|
||||
}
|
||||
}
|
||||
|
||||
case u @ UnresolvedFunction(nameParts, arguments, isDistinct, filter, ignoreNulls) =>
|
||||
withPosition(u) {
|
||||
expandIdentifier(nameParts) match {
|
||||
case NonSessionCatalogAndIdentifier(catalog, ident) =>
|
||||
if (!catalog.isFunctionCatalog) {
|
||||
throw new AnalysisException(s"Trying to lookup function '$ident' in " +
|
||||
s"catalog '${catalog.name()}', but it is not a FunctionCatalog.")
|
||||
}
|
||||
|
||||
val unbound = catalog.asFunctionCatalog.loadFunction(ident)
|
||||
val inputType = StructType(arguments.zipWithIndex.map {
|
||||
case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
|
||||
})
|
||||
val bound = try {
|
||||
unbound.bind(inputType)
|
||||
} catch {
|
||||
case unsupported: UnsupportedOperationException =>
|
||||
throw new AnalysisException(s"Function '${unbound.name}' cannot process " +
|
||||
s"input: (${arguments.map(_.dataType.simpleString).mkString(", ")}): " +
|
||||
unsupported.getMessage, cause = Some(unsupported))
|
||||
}
|
||||
|
||||
bound match {
|
||||
case scalarFunc: ScalarFunction[_] =>
|
||||
processV2ScalarFunction(scalarFunc, inputType, arguments, isDistinct,
|
||||
filter, ignoreNulls)
|
||||
case aggFunc: V2AggregateFunction[_, _] =>
|
||||
processV2AggregateFunction(aggFunc, arguments, isDistinct, filter,
|
||||
ignoreNulls)
|
||||
case _ =>
|
||||
failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction" +
|
||||
s" or AggregateFunction")
|
||||
}
|
||||
|
||||
case _ => u
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def processV2ScalarFunction(
|
||||
scalarFunc: ScalarFunction[_],
|
||||
inputType: StructType,
|
||||
arguments: Seq[Expression],
|
||||
isDistinct: Boolean,
|
||||
filter: Option[Expression],
|
||||
ignoreNulls: Boolean): Expression = {
|
||||
if (isDistinct) {
|
||||
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
|
||||
scalarFunc.name(), "DISTINCT")
|
||||
} else if (filter.isDefined) {
|
||||
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
|
||||
scalarFunc.name(), "FILTER clause")
|
||||
} else if (ignoreNulls) {
|
||||
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
|
||||
scalarFunc.name(), "IGNORE NULLS")
|
||||
} else {
|
||||
// TODO: implement type coercion by looking at input type from the UDF. We
|
||||
// may also want to check if the parameter types from the magic method
|
||||
// match the input type through `BoundFunction.inputTypes`.
|
||||
val argClasses = inputType.fields.map(_.dataType)
|
||||
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
|
||||
case Some(_) =>
|
||||
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
|
||||
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
|
||||
arguments, returnNullable = scalarFunc.isResultNullable)
|
||||
case _ =>
|
||||
// TODO: handle functions defined in Scala too - in Scala, even if a
|
||||
// subclass do not override the default method in parent interface
|
||||
// defined in Java, the method can still be found from
|
||||
// `getDeclaredMethod`.
|
||||
// since `inputType` is a `StructType`, it is mapped to a `InternalRow`
|
||||
// which we can use to lookup the `produceResult` method.
|
||||
findMethod(scalarFunc, "produceResult", Seq(inputType)) match {
|
||||
case Some(_) =>
|
||||
ApplyFunctionExpression(scalarFunc, arguments)
|
||||
case None =>
|
||||
failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" +
|
||||
s" magic method nor override 'produceResult'")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def processV2AggregateFunction(
|
||||
aggFunc: V2AggregateFunction[_, _],
|
||||
arguments: Seq[Expression],
|
||||
isDistinct: Boolean,
|
||||
filter: Option[Expression],
|
||||
ignoreNulls: Boolean): Expression = {
|
||||
if (ignoreNulls) {
|
||||
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
|
||||
aggFunc.name(), "IGNORE NULLS")
|
||||
}
|
||||
val aggregator = V2Aggregator(aggFunc, arguments)
|
||||
AggregateExpression(aggregator, Complete, isDistinct, filter)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the input `fn` implements the given `methodName` with parameter types specified
|
||||
* via `inputType`.
|
||||
*/
|
||||
private def findMethod(
|
||||
fn: BoundFunction,
|
||||
methodName: String,
|
||||
inputType: Seq[DataType]): Option[Method] = {
|
||||
val cls = fn.getClass
|
||||
try {
|
||||
val argClasses = inputType.map(ScalaReflection.dataTypeJavaClass)
|
||||
Some(cls.getDeclaredMethod(methodName, argClasses: _*))
|
||||
} catch {
|
||||
case _: NoSuchMethodException =>
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,10 +17,10 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.analysis
|
||||
|
||||
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.connector.catalog.{CatalogManager, LookupCatalog}
|
||||
import org.apache.spark.sql.errors.QueryCompilationErrors
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
|
@ -30,13 +30,14 @@ import org.apache.spark.sql.types.DataType
|
|||
* so we need to resolve higher order function when all children are either resolved or a lambda
|
||||
* function.
|
||||
*/
|
||||
case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] {
|
||||
case class ResolveHigherOrderFunctions(catalogManager: CatalogManager)
|
||||
extends Rule[LogicalPlan] with LookupCatalog {
|
||||
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
|
||||
case u @ UnresolvedFunction(fn, children, false, filter, ignoreNulls)
|
||||
case u @ UnresolvedFunction(AsFunctionIdentifier(ident), children, false, filter, ignoreNulls)
|
||||
if hasLambdaAndResolvedArguments(children) =>
|
||||
withPosition(u) {
|
||||
catalog.lookupFunction(fn, children) match {
|
||||
catalogManager.v1SessionCatalog.lookupFunction(ident, children) match {
|
||||
case func: HigherOrderFunction =>
|
||||
filter.foreach(_.failAnalysis("FILTER predicate specified, " +
|
||||
s"but ${func.prettyName} is not an aggregate function"))
|
||||
|
|
|
@ -269,12 +269,13 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio
|
|||
}
|
||||
|
||||
case class UnresolvedFunction(
|
||||
name: FunctionIdentifier,
|
||||
nameParts: Seq[String],
|
||||
arguments: Seq[Expression],
|
||||
isDistinct: Boolean,
|
||||
filter: Option[Expression] = None,
|
||||
ignoreNulls: Boolean = false)
|
||||
extends Expression with Unevaluable {
|
||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||
|
||||
override def children: Seq[Expression] = arguments ++ filter.toSeq
|
||||
|
||||
|
@ -282,10 +283,10 @@ case class UnresolvedFunction(
|
|||
override def nullable: Boolean = throw new UnresolvedException("nullable")
|
||||
override lazy val resolved = false
|
||||
|
||||
override def prettyName: String = name.unquotedString
|
||||
override def prettyName: String = nameParts.quoted
|
||||
override def toString: String = {
|
||||
val distinct = if (isDistinct) "distinct " else ""
|
||||
s"'$name($distinct${children.mkString(", ")})"
|
||||
s"'${nameParts.quoted}($distinct${children.mkString(", ")})"
|
||||
}
|
||||
|
||||
override protected def withNewChildrenInternal(
|
||||
|
@ -299,8 +300,17 @@ case class UnresolvedFunction(
|
|||
}
|
||||
|
||||
object UnresolvedFunction {
|
||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||
|
||||
def apply(
|
||||
name: FunctionIdentifier,
|
||||
arguments: Seq[Expression],
|
||||
isDistinct: Boolean): UnresolvedFunction = {
|
||||
UnresolvedFunction(name.asMultipart, arguments, isDistinct)
|
||||
}
|
||||
|
||||
def apply(name: String, arguments: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = {
|
||||
UnresolvedFunction(FunctionIdentifier(name, None), arguments, isDistinct)
|
||||
UnresolvedFunction(Seq(name), arguments, isDistinct)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1576,6 +1576,8 @@ class SessionCatalog(
|
|||
name: FunctionIdentifier,
|
||||
children: Seq[Expression],
|
||||
registry: FunctionRegistryBase[T]): T = synchronized {
|
||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
|
||||
|
||||
// Note: the implementation of this function is a little bit convoluted.
|
||||
// We probably shouldn't use a single FunctionRegistry to register all three kinds of functions
|
||||
// (built-in, temp, and external).
|
||||
|
@ -1598,7 +1600,8 @@ class SessionCatalog(
|
|||
case Seq() => getCurrentDatabase
|
||||
case Seq(_, db) => db
|
||||
case Seq(catalog, namespace @ _*) =>
|
||||
throw QueryCompilationErrors.v2CatalogNotSupportFunctionError(catalog, namespace)
|
||||
throw new IllegalStateException(s"[BUG] unexpected v2 catalog: $catalog, and " +
|
||||
s"namespace: ${namespace.quoted} in v1 function lookup")
|
||||
}
|
||||
|
||||
// If the name itself is not qualified, add the current database to it.
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
|
||||
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
case class ApplyFunctionExpression(
|
||||
function: ScalarFunction[_],
|
||||
children: Seq[Expression]) extends Expression with UserDefinedExpression with CodegenFallback {
|
||||
override def nullable: Boolean = function.isResultNullable
|
||||
override def name: String = function.name()
|
||||
override def dataType: DataType = function.resultType()
|
||||
|
||||
private lazy val reusedRow = new GenericInternalRow(children.size)
|
||||
|
||||
/** Returns the result of evaluating this expression on a given input Row */
|
||||
override def eval(input: InternalRow): Any = {
|
||||
children.zipWithIndex.foreach {
|
||||
case (expr, pos) =>
|
||||
reusedRow.update(pos, expr.eval(input))
|
||||
}
|
||||
|
||||
function.produceResult(reusedRow)
|
||||
}
|
||||
|
||||
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
|
||||
copy(children = newChildren)
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.expressions.aggregate
|
||||
|
||||
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection}
|
||||
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction}
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
case class V2Aggregator[BUF <: java.io.Serializable, OUT](
|
||||
aggrFunc: V2AggregateFunction[BUF, OUT],
|
||||
children: Seq[Expression],
|
||||
mutableAggBufferOffset: Int = 0,
|
||||
inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[BUF] {
|
||||
private[this] lazy val inputProjection = UnsafeProjection.create(children)
|
||||
|
||||
override def nullable: Boolean = aggrFunc.isResultNullable
|
||||
override def dataType: DataType = aggrFunc.resultType()
|
||||
override def createAggregationBuffer(): BUF = aggrFunc.newAggregationState()
|
||||
|
||||
override def update(buffer: BUF, input: InternalRow): BUF = {
|
||||
aggrFunc.update(buffer, inputProjection(input))
|
||||
}
|
||||
|
||||
override def merge(buffer: BUF, input: BUF): BUF = aggrFunc.merge(buffer, input)
|
||||
|
||||
override def eval(buffer: BUF): Any = {
|
||||
aggrFunc.produceResult(buffer)
|
||||
}
|
||||
|
||||
override def serialize(buffer: BUF): Array[Byte] = {
|
||||
val bos = new ByteArrayOutputStream()
|
||||
val out = new ObjectOutputStream(bos)
|
||||
out.writeObject(buffer)
|
||||
out.close()
|
||||
bos.toByteArray
|
||||
}
|
||||
|
||||
override def deserialize(bytes: Array[Byte]): BUF = {
|
||||
val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
|
||||
in.readObject().asInstanceOf[BUF]
|
||||
}
|
||||
|
||||
def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): V2Aggregator[BUF, OUT] =
|
||||
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
|
||||
|
||||
def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): V2Aggregator[BUF, OUT] =
|
||||
copy(inputAggBufferOffset = newInputAggBufferOffset)
|
||||
|
||||
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
|
||||
copy(children = newChildren)
|
||||
}
|
||||
|
|
@ -1816,7 +1816,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
|
|||
val ignoreNulls =
|
||||
Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE).getOrElse(false)
|
||||
val function = UnresolvedFunction(
|
||||
getFunctionIdentifier(ctx.functionName), arguments, isDistinct, filter, ignoreNulls)
|
||||
getFunctionMultiparts(ctx.functionName), arguments, isDistinct, filter, ignoreNulls)
|
||||
|
||||
// Check if the function is evaluated in a windowed context.
|
||||
ctx.windowSpec match {
|
||||
|
@ -1828,15 +1828,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Create a function database (optional) and name pair, for multipartIdentifier.
|
||||
* This is used in CREATE FUNCTION, DROP FUNCTION, SHOWFUNCTIONS.
|
||||
*/
|
||||
protected def visitFunctionName(ctx: MultipartIdentifierContext): FunctionIdentifier = {
|
||||
visitFunctionName(ctx, ctx.parts.asScala.map(_.getText).toSeq)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a function database (optional) and name pair.
|
||||
*/
|
||||
|
@ -1867,6 +1858,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
|
|||
}
|
||||
}
|
||||
|
||||
protected def getFunctionMultiparts(ctx: FunctionNameContext): Seq[String] = {
|
||||
if (ctx.qualifiedName != null) {
|
||||
ctx.qualifiedName().identifier().asScala.map(_.getText).toSeq
|
||||
} else {
|
||||
Seq(ctx.getText)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an [[LambdaFunction]].
|
||||
*/
|
||||
|
|
|
@ -83,12 +83,36 @@ private[sql] object CatalogV2Implicits {
|
|||
throw new AnalysisException(
|
||||
s"Cannot use catalog ${plugin.name}: does not support namespaces")
|
||||
}
|
||||
|
||||
def isFunctionCatalog: Boolean = plugin match {
|
||||
case _: FunctionCatalog => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
def asFunctionCatalog: FunctionCatalog = plugin match {
|
||||
case functionCatalog: FunctionCatalog =>
|
||||
functionCatalog
|
||||
case _ =>
|
||||
throw new AnalysisException(
|
||||
s"Cannot use catalog '${plugin.name}': not a FunctionCatalog")
|
||||
}
|
||||
}
|
||||
|
||||
implicit class NamespaceHelper(namespace: Array[String]) {
|
||||
def quoted: String = namespace.map(quoteIfNeeded).mkString(".")
|
||||
}
|
||||
|
||||
implicit class FunctionIdentifierHelper(ident: FunctionIdentifier) {
|
||||
def asMultipart: Seq[String] = {
|
||||
ident.database match {
|
||||
case Some(db) =>
|
||||
Seq(db, ident.funcName)
|
||||
case _ =>
|
||||
Seq(ident.funcName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
implicit class IdentifierHelper(ident: Identifier) {
|
||||
def quoted: String = {
|
||||
if (ident.namespace.nonEmpty) {
|
||||
|
@ -132,6 +156,14 @@ private[sql] object CatalogV2Implicits {
|
|||
s"$quoted is not a valid TableIdentifier as it has more than 2 name parts.")
|
||||
}
|
||||
|
||||
def asFunctionIdentifier: FunctionIdentifier = parts match {
|
||||
case Seq(funcName) => FunctionIdentifier(funcName)
|
||||
case Seq(dbName, funcName) => FunctionIdentifier(funcName, Some(dbName))
|
||||
case _ =>
|
||||
throw new AnalysisException(
|
||||
s"$quoted is not a valid FunctionIdentifier as it has more than 2 name parts.")
|
||||
}
|
||||
|
||||
def quoted: String = parts.map(quoteIfNeeded).mkString(".")
|
||||
}
|
||||
|
||||
|
|
|
@ -154,6 +154,28 @@ private[sql] trait LookupCatalog extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
object AsFunctionIdentifier {
|
||||
def unapply(parts: Seq[String]): Option[FunctionIdentifier] = {
|
||||
def namesToFunctionIdentifier(names: Seq[String]): Option[FunctionIdentifier] = names match {
|
||||
case Seq(name) => Some(FunctionIdentifier(name))
|
||||
case Seq(database, name) => Some(FunctionIdentifier(name, Some(database)))
|
||||
case _ => None
|
||||
}
|
||||
parts match {
|
||||
case Seq(name)
|
||||
if catalogManager.v1SessionCatalog.isRegisteredFunction(FunctionIdentifier(name)) =>
|
||||
Some(FunctionIdentifier(name))
|
||||
case CatalogAndMultipartIdentifier(None, names)
|
||||
if CatalogV2Util.isSessionCatalog(currentCatalog) =>
|
||||
namesToFunctionIdentifier(names)
|
||||
case CatalogAndMultipartIdentifier(Some(catalog), names)
|
||||
if CatalogV2Util.isSessionCatalog(catalog) =>
|
||||
namesToFunctionIdentifier(names)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def parseSessionCatalogFunctionIdentifier(nameParts: Seq[String]): FunctionIdentifier = {
|
||||
if (nameParts.length == 1 && catalogManager.v1SessionCatalog.isTempFunction(nameParts.head)) {
|
||||
return FunctionIdentifier(nameParts.head)
|
||||
|
|
|
@ -617,12 +617,6 @@ private[spark] object QueryCompilationErrors {
|
|||
s"the function '$func', please make sure it is on the classpath")
|
||||
}
|
||||
|
||||
def v2CatalogNotSupportFunctionError(
|
||||
catalog: String, namespace: Seq[String]): Throwable = {
|
||||
new AnalysisException("V2 catalog does not support functions yet. " +
|
||||
s"catalog: $catalog, namespace: '${namespace.quoted}'")
|
||||
}
|
||||
|
||||
def resourceTypeNotSupportedError(resourceType: String): Throwable = {
|
||||
new AnalysisException(s"Resource Type '$resourceType' is not supported.")
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog,
|
|||
import org.apache.spark.sql.catalyst.expressions.Alias
|
||||
import org.apache.spark.sql.catalyst.plans.PlanTest
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||
|
||||
class LookupFunctionsSuite extends PlanTest {
|
||||
|
||||
|
@ -49,7 +50,7 @@ class LookupFunctionsSuite extends PlanTest {
|
|||
|
||||
assert(externalCatalog.getFunctionExistsCalledTimes == 1)
|
||||
assert(analyzer.LookupFunctions.normalizeFuncName
|
||||
(unresolvedPersistentFunc.name).database == Some("default"))
|
||||
(unresolvedPersistentFunc.nameParts.asFunctionIdentifier).database == Some("default"))
|
||||
}
|
||||
|
||||
test("SPARK-23486: the functionExists for the Registered function check") {
|
||||
|
@ -70,9 +71,9 @@ class LookupFunctionsSuite extends PlanTest {
|
|||
table("TaBlE"))
|
||||
analyzer.LookupFunctions.apply(plan)
|
||||
|
||||
assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2)
|
||||
assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 4)
|
||||
assert(analyzer.LookupFunctions.normalizeFuncName
|
||||
(unresolvedRegisteredFunc.name).database == Some("default"))
|
||||
(unresolvedRegisteredFunc.nameParts.asFunctionIdentifier).database == Some("default"))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, FakeV2SessionCatalog, NoSuchNamespaceException}
|
||||
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
|
||||
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog => V1InMemoryCatalog, SessionCatalog}
|
||||
import org.apache.spark.sql.catalyst.plans.SQLHelper
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.util.CaseInsensitiveStringMap
|
||||
|
@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
|
|||
class CatalogManagerSuite extends SparkFunSuite with SQLHelper {
|
||||
|
||||
private def createSessionCatalog(): SessionCatalog = {
|
||||
val catalog = new InMemoryCatalog()
|
||||
val catalog = new V1InMemoryCatalog()
|
||||
catalog.createDatabase(
|
||||
CatalogDatabase(SessionCatalog.DEFAULT_DATABASE, "", new URI("fake"), Map.empty),
|
||||
ignoreIfExists = true)
|
||||
|
|
|
@ -24,14 +24,15 @@ import scala.collection.JavaConverters._
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
|
||||
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
|
||||
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
|
||||
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
|
||||
import org.apache.spark.sql.connector.expressions.LogicalExpressions
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType}
|
||||
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType}
|
||||
import org.apache.spark.sql.util.CaseInsensitiveStringMap
|
||||
|
||||
class TableCatalogSuite extends SparkFunSuite {
|
||||
class CatalogSuite extends SparkFunSuite {
|
||||
import CatalogV2Implicits._
|
||||
|
||||
private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String]
|
||||
|
@ -39,8 +40,8 @@ class TableCatalogSuite extends SparkFunSuite {
|
|||
.add("id", IntegerType)
|
||||
.add("data", StringType)
|
||||
|
||||
private def newCatalog(): TableCatalog with SupportsNamespaces = {
|
||||
val newCatalog = new InMemoryTableCatalog
|
||||
private def newCatalog(): InMemoryCatalog = {
|
||||
val newCatalog = new InMemoryCatalog
|
||||
newCatalog.initialize("test", CaseInsensitiveStringMap.empty())
|
||||
newCatalog
|
||||
}
|
||||
|
@ -927,4 +928,43 @@ class TableCatalogSuite extends SparkFunSuite {
|
|||
assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 2)
|
||||
assert(partTable.rows.isEmpty)
|
||||
}
|
||||
|
||||
val function: UnboundFunction = new UnboundFunction {
|
||||
override def bind(inputType: StructType): BoundFunction = new ScalarFunction[Int] {
|
||||
override def inputTypes(): Array[DataType] = Array(IntegerType)
|
||||
override def resultType(): DataType = IntegerType
|
||||
override def name(): String = "my_bound_function"
|
||||
}
|
||||
override def description(): String = "my_function"
|
||||
override def name(): String = "my_function"
|
||||
}
|
||||
|
||||
test("list functions") {
|
||||
val catalog = newCatalog()
|
||||
val ident1 = Identifier.of(Array("ns1", "ns2"), "func1")
|
||||
val ident2 = Identifier.of(Array("ns1", "ns2"), "func2")
|
||||
val ident3 = Identifier.of(Array("ns1", "ns3"), "func3")
|
||||
|
||||
catalog.createNamespace(Array("ns1", "ns2"), emptyProps)
|
||||
catalog.createNamespace(Array("ns1", "ns3"), emptyProps)
|
||||
catalog.createFunction(ident1, function)
|
||||
catalog.createFunction(ident2, function)
|
||||
catalog.createFunction(ident3, function)
|
||||
|
||||
assert(catalog.listFunctions(Array("ns1", "ns2")).toSet === Set(ident1, ident2))
|
||||
assert(catalog.listFunctions(Array("ns1", "ns3")).toSet === Set(ident3))
|
||||
assert(catalog.listFunctions(Array("ns1")).toSet == Set())
|
||||
intercept[NoSuchNamespaceException](catalog.listFunctions(Array("ns2")))
|
||||
}
|
||||
|
||||
test("lookup function") {
|
||||
val catalog = newCatalog()
|
||||
val ident = Identifier.of(Array("ns"), "func")
|
||||
catalog.createNamespace(Array("ns"), emptyProps)
|
||||
catalog.createFunction(ident, function)
|
||||
|
||||
assert(catalog.loadFunction(ident) == function)
|
||||
intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns"), "func1")))
|
||||
intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns1"), "func")))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.connector.catalog
|
||||
|
||||
import java.util
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, NoSuchNamespaceException}
|
||||
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
|
||||
|
||||
class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog {
|
||||
protected val functions: util.Map[Identifier, UnboundFunction] =
|
||||
new ConcurrentHashMap[Identifier, UnboundFunction]()
|
||||
|
||||
override protected def allNamespaces: Seq[Seq[String]] = {
|
||||
(tables.keySet.asScala.map(_.namespace.toSeq) ++
|
||||
functions.keySet.asScala.map(_.namespace.toSeq) ++
|
||||
namespaces.keySet.asScala).toSeq.distinct
|
||||
}
|
||||
|
||||
override def listFunctions(namespace: Array[String]): Array[Identifier] = {
|
||||
if (namespace.isEmpty || namespaceExists(namespace)) {
|
||||
functions.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
|
||||
} else {
|
||||
throw new NoSuchNamespaceException(namespace)
|
||||
}
|
||||
}
|
||||
|
||||
override def loadFunction(ident: Identifier): UnboundFunction = {
|
||||
Option(functions.get(ident)) match {
|
||||
case Some(func) =>
|
||||
func
|
||||
case _ =>
|
||||
throw new NoSuchFunctionException(ident)
|
||||
}
|
||||
}
|
||||
|
||||
def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction = {
|
||||
functions.put(ident, fn)
|
||||
}
|
||||
}
|
|
@ -138,7 +138,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
|
|||
}
|
||||
|
||||
class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamespaces {
|
||||
private def allNamespaces: Seq[Seq[String]] = {
|
||||
protected def allNamespaces: Seq[Seq[String]] = {
|
||||
(tables.keySet.asScala.map(_.namespace.toSeq) ++ namespaces.keySet.asScala).toSeq.distinct
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test.org.apache.spark.sql.connector.catalog.functions;
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow;
|
||||
import org.apache.spark.sql.connector.catalog.functions.AggregateFunction;
|
||||
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
|
||||
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
|
||||
import org.apache.spark.sql.types.DataType;
|
||||
import org.apache.spark.sql.types.DataTypes;
|
||||
import org.apache.spark.sql.types.DoubleType;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
public class JavaAverage implements UnboundFunction {
|
||||
@Override
|
||||
public String name() {
|
||||
return "avg";
|
||||
}
|
||||
|
||||
@Override
|
||||
public BoundFunction bind(StructType inputType) {
|
||||
if (inputType.fields().length != 1) {
|
||||
throw new UnsupportedOperationException("Expect exactly one argument");
|
||||
}
|
||||
if (inputType.fields()[0].dataType() instanceof DoubleType) {
|
||||
return new JavaDoubleAverage();
|
||||
}
|
||||
throw new UnsupportedOperationException("Unsupported non-integral type: " +
|
||||
inputType.fields()[0].dataType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String description() {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static class JavaDoubleAverage implements AggregateFunction<State<Double>, Double> {
|
||||
@Override
|
||||
public State<Double> newAggregationState() {
|
||||
return new State<>(0.0, 0.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public State<Double> update(State<Double> state, InternalRow input) {
|
||||
if (input.isNullAt(0)) {
|
||||
return state;
|
||||
}
|
||||
return new State<>(state.sum + input.getDouble(0), state.count + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double produceResult(State<Double> state) {
|
||||
return state.sum / state.count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public State<Double> merge(State<Double> leftState, State<Double> rightState) {
|
||||
return new State<>(leftState.sum + rightState.sum, leftState.count + rightState.count);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType[] inputTypes() {
|
||||
return new DataType[] { DataTypes.DoubleType };
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType resultType() {
|
||||
return DataTypes.DoubleType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String name() {
|
||||
return "davg";
|
||||
}
|
||||
}
|
||||
|
||||
public static class State<T> implements Serializable {
|
||||
T sum, count;
|
||||
|
||||
State(T sum, T count) {
|
||||
this.sum = sum;
|
||||
this.count = count;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,122 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test.org.apache.spark.sql.connector.catalog.functions;
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow;
|
||||
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
|
||||
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
|
||||
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
|
||||
import org.apache.spark.sql.types.DataType;
|
||||
import org.apache.spark.sql.types.DataTypes;
|
||||
import org.apache.spark.sql.types.StringType;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
import org.apache.spark.unsafe.types.UTF8String;
|
||||
|
||||
public class JavaStrLen implements UnboundFunction {
|
||||
private final BoundFunction fn;
|
||||
|
||||
public JavaStrLen(BoundFunction fn) {
|
||||
this.fn = fn;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String name() {
|
||||
return "strlen";
|
||||
}
|
||||
|
||||
@Override
|
||||
public BoundFunction bind(StructType inputType) {
|
||||
if (inputType.fields().length != 1) {
|
||||
throw new UnsupportedOperationException("Expect exactly one argument");
|
||||
}
|
||||
|
||||
if (inputType.fields()[0].dataType() instanceof StringType) {
|
||||
return fn;
|
||||
}
|
||||
|
||||
throw new UnsupportedOperationException("Except StringType");
|
||||
}
|
||||
|
||||
@Override
|
||||
public String description() {
|
||||
return "strlen: returns the length of the input string\n" +
|
||||
" strlen(string) -> int";
|
||||
}
|
||||
|
||||
public static class JavaStrLenDefault implements ScalarFunction<Integer> {
|
||||
@Override
|
||||
public DataType[] inputTypes() {
|
||||
return new DataType[] { DataTypes.StringType };
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType resultType() {
|
||||
return DataTypes.IntegerType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String name() {
|
||||
return "strlen";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer produceResult(InternalRow input) {
|
||||
String str = input.getString(0);
|
||||
return str.length();
|
||||
}
|
||||
}
|
||||
|
||||
public static class JavaStrLenMagic implements ScalarFunction<Integer> {
|
||||
@Override
|
||||
public DataType[] inputTypes() {
|
||||
return new DataType[] { DataTypes.StringType };
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType resultType() {
|
||||
return DataTypes.IntegerType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String name() {
|
||||
return "strlen";
|
||||
}
|
||||
|
||||
public int invoke(UTF8String str) {
|
||||
return str.toString().length();
|
||||
}
|
||||
}
|
||||
|
||||
public static class JavaStrLenNoImpl implements ScalarFunction<Integer> {
|
||||
@Override
|
||||
public DataType[] inputTypes() {
|
||||
return new DataType[] { DataTypes.StringType };
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType resultType() {
|
||||
return DataTypes.IntegerType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String name() {
|
||||
return "strlen";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,420 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.connector
|
||||
|
||||
import java.util
|
||||
import java.util.Collections
|
||||
|
||||
import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaStrLen}
|
||||
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.sql.{AnalysisException, Row}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
|
||||
import org.apache.spark.sql.connector.catalog.functions._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
|
||||
private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String]
|
||||
|
||||
private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = {
|
||||
catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn)
|
||||
}
|
||||
|
||||
test("undefined function") {
|
||||
assert(intercept[AnalysisException](
|
||||
sql("SELECT testcat.non_exist('abc')").collect()
|
||||
).getMessage.contains("Undefined function"))
|
||||
}
|
||||
|
||||
test("non-function catalog") {
|
||||
withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) {
|
||||
assert(intercept[AnalysisException](
|
||||
sql("SELECT testcat.strlen('abc')").collect()
|
||||
).getMessage.contains("is not a FunctionCatalog"))
|
||||
}
|
||||
}
|
||||
|
||||
test("built-in with non-function catalog should still work") {
|
||||
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat",
|
||||
"spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) {
|
||||
checkAnswer(sql("SELECT length('abc')"), Row(3))
|
||||
}
|
||||
}
|
||||
|
||||
test("built-in with default v2 function catalog") {
|
||||
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
|
||||
checkAnswer(sql("SELECT length('abc')"), Row(3))
|
||||
}
|
||||
}
|
||||
|
||||
test("looking up higher-order function with non-session catalog") {
|
||||
checkAnswer(sql("SELECT transform(array(1, 2, 3), x -> x + 1)"),
|
||||
Row(Array(2, 3, 4)) :: Nil)
|
||||
}
|
||||
|
||||
test("built-in override with default v2 function catalog") {
|
||||
// a built-in function with the same name should take higher priority
|
||||
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
|
||||
addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl))
|
||||
checkAnswer(sql("SELECT length('abc')"), Row(3))
|
||||
}
|
||||
}
|
||||
|
||||
test("built-in override with non-session catalog") {
|
||||
addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl))
|
||||
checkAnswer(sql("SELECT length('abc')"), Row(3))
|
||||
}
|
||||
|
||||
test("temp function override with default v2 function catalog") {
|
||||
val className = "test.org.apache.spark.sql.JavaStringLength"
|
||||
sql(s"CREATE FUNCTION length AS '$className'")
|
||||
|
||||
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
|
||||
addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl))
|
||||
checkAnswer(sql("SELECT length('abc')"), Row(3))
|
||||
}
|
||||
}
|
||||
|
||||
test("view should use captured catalog and namespace for function lookup") {
|
||||
val viewName = "my_view"
|
||||
withView(viewName) {
|
||||
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "my_avg"), IntegralAverage)
|
||||
sql("USE ns")
|
||||
sql(s"CREATE TEMPORARY VIEW $viewName AS SELECT my_avg(col1) FROM values (1), (2), (3)")
|
||||
}
|
||||
|
||||
// change default catalog and namespace and add a function with the same name but with no
|
||||
// implementation
|
||||
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat2") {
|
||||
catalog("testcat2").asInstanceOf[SupportsNamespaces]
|
||||
.createNamespace(Array("ns2"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns2"), "my_avg"), NoImplAverage)
|
||||
sql("USE ns2")
|
||||
checkAnswer(sql(s"SELECT * FROM $viewName"), Row(2.0) :: Nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("scalar function: with default produceResult method") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
|
||||
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("scalar function: with default produceResult method w/ expression") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
|
||||
checkAnswer(sql("SELECT testcat.ns.strlen(substr('abcde', 3))"), Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("scalar function: lookup magic method") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic))
|
||||
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("scalar function: lookup magic method w/ expression") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic))
|
||||
checkAnswer(sql("SELECT testcat.ns.strlen(substr('abcde', 3))"), Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("scalar function: bad magic method") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagic))
|
||||
assert(intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect())
|
||||
.getMessage.contains("Cannot find a compatible"))
|
||||
}
|
||||
|
||||
test("scalar function: bad magic method with default impl") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagicWithDefault))
|
||||
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("scalar function: no implementation found") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenNoImpl))
|
||||
intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect())
|
||||
}
|
||||
|
||||
test("scalar function: invalid parameter type or length") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
|
||||
|
||||
assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen(42)"))
|
||||
.getMessage.contains("Expect StringType"))
|
||||
assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('a', 'b')"))
|
||||
.getMessage.contains("Expect exactly one argument"))
|
||||
}
|
||||
|
||||
test("scalar function: default produceResult in Java") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"),
|
||||
new JavaStrLen(new JavaStrLenDefault))
|
||||
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("scalar function: magic method in Java") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"),
|
||||
new JavaStrLen(new JavaStrLenMagic))
|
||||
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("scalar function: no implementation found in Java") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"),
|
||||
new JavaStrLen(new JavaStrLenNoImpl))
|
||||
assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')").collect())
|
||||
.getMessage.contains("neither implement magic method nor override 'produceResult'"))
|
||||
}
|
||||
|
||||
test("bad bound function (neither scalar nor aggregate)") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(BadBoundFunction))
|
||||
|
||||
assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')"))
|
||||
.getMessage.contains("does not implement ScalarFunction or AggregateFunction"))
|
||||
}
|
||||
|
||||
test("aggregate function: lookup int average") {
|
||||
import testImplicits._
|
||||
val t = "testcat.ns.t"
|
||||
withTable(t) {
|
||||
addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
|
||||
|
||||
(1 to 100).toDF("i").write.saveAsTable(t)
|
||||
checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(50) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("aggregate function: lookup long average") {
|
||||
import testImplicits._
|
||||
val t = "testcat.ns.t"
|
||||
withTable(t) {
|
||||
addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
|
||||
|
||||
(1L to 100L).toDF("i").write.saveAsTable(t)
|
||||
checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(50) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("aggregate function: lookup double average in Java") {
|
||||
import testImplicits._
|
||||
val t = "testcat.ns.t"
|
||||
withTable(t) {
|
||||
addFunction(Identifier.of(Array("ns"), "avg"), new JavaAverage)
|
||||
|
||||
Seq(1.toDouble, 2.toDouble, 3.toDouble).toDF("i").write.saveAsTable(t)
|
||||
checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(2.0) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("aggregate function: lookup int average w/ expression") {
|
||||
import testImplicits._
|
||||
val t = "testcat.ns.t"
|
||||
withTable(t) {
|
||||
addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
|
||||
|
||||
(1 to 100).toDF("i").write.saveAsTable(t)
|
||||
checkAnswer(sql(s"SELECT testcat.ns.avg(i * 10) from $t"), Row(505) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("aggregate function: unsupported input type") {
|
||||
import testImplicits._
|
||||
val t = "testcat.ns.t"
|
||||
withTable(t) {
|
||||
addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
|
||||
|
||||
Seq(1.toShort, 2.toShort).toDF("i").write.saveAsTable(t)
|
||||
assert(intercept[AnalysisException](sql(s"SELECT testcat.ns.avg(i) from $t"))
|
||||
.getMessage.contains("Unsupported non-integral type: ShortType"))
|
||||
}
|
||||
}
|
||||
|
||||
private case class StrLen(impl: BoundFunction) extends UnboundFunction {
|
||||
override def description(): String =
|
||||
"""strlen: returns the length of the input string
|
||||
| strlen(string) -> int""".stripMargin
|
||||
override def name(): String = "strlen"
|
||||
|
||||
override def bind(inputType: StructType): BoundFunction = {
|
||||
if (inputType.fields.length != 1) {
|
||||
throw new UnsupportedOperationException("Expect exactly one argument");
|
||||
}
|
||||
inputType.fields(0).dataType match {
|
||||
case StringType => impl
|
||||
case _ =>
|
||||
throw new UnsupportedOperationException("Expect StringType")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private case object StrLenDefault extends ScalarFunction[Int] {
|
||||
override def inputTypes(): Array[DataType] = Array(StringType)
|
||||
override def resultType(): DataType = IntegerType
|
||||
override def name(): String = "strlen_default"
|
||||
|
||||
override def produceResult(input: InternalRow): Int = {
|
||||
val s = input.getString(0)
|
||||
s.length
|
||||
}
|
||||
}
|
||||
|
||||
private case object StrLenMagic extends ScalarFunction[Int] {
|
||||
override def inputTypes(): Array[DataType] = Array(StringType)
|
||||
override def resultType(): DataType = IntegerType
|
||||
override def name(): String = "strlen_magic"
|
||||
|
||||
def invoke(input: UTF8String): Int = {
|
||||
input.toString.length
|
||||
}
|
||||
}
|
||||
|
||||
private case object StrLenBadMagic extends ScalarFunction[Int] {
|
||||
override def inputTypes(): Array[DataType] = Array(StringType)
|
||||
override def resultType(): DataType = IntegerType
|
||||
override def name(): String = "strlen_bad_magic"
|
||||
|
||||
def invoke(input: String): Int = {
|
||||
input.length
|
||||
}
|
||||
}
|
||||
|
||||
private case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
|
||||
override def inputTypes(): Array[DataType] = Array(StringType)
|
||||
override def resultType(): DataType = IntegerType
|
||||
override def name(): String = "strlen_bad_magic"
|
||||
|
||||
def invoke(input: String): Int = {
|
||||
input.length
|
||||
}
|
||||
|
||||
override def produceResult(input: InternalRow): Int = {
|
||||
val s = input.getString(0)
|
||||
s.length
|
||||
}
|
||||
}
|
||||
|
||||
private case object StrLenNoImpl extends ScalarFunction[Int] {
|
||||
override def inputTypes(): Array[DataType] = Array(StringType)
|
||||
override def resultType(): DataType = IntegerType
|
||||
override def name(): String = "strlen_noimpl"
|
||||
}
|
||||
|
||||
private case object BadBoundFunction extends BoundFunction {
|
||||
override def inputTypes(): Array[DataType] = Array(StringType)
|
||||
override def resultType(): DataType = IntegerType
|
||||
override def name(): String = "bad_bound_func"
|
||||
}
|
||||
|
||||
object IntegralAverage extends UnboundFunction {
|
||||
override def name(): String = "iavg"
|
||||
|
||||
override def bind(inputType: StructType): BoundFunction = {
|
||||
if (inputType.fields.length > 1) {
|
||||
throw new UnsupportedOperationException("Too many arguments")
|
||||
}
|
||||
|
||||
inputType.fields(0).dataType match {
|
||||
case _: IntegerType => IntAverage
|
||||
case _: LongType => LongAverage
|
||||
case dataType =>
|
||||
throw new UnsupportedOperationException(s"Unsupported non-integral type: $dataType")
|
||||
}
|
||||
}
|
||||
|
||||
override def description(): String =
|
||||
"""iavg: produces an average using integer division, ignoring nulls
|
||||
| iavg(int) -> int
|
||||
| iavg(bigint) -> bigint""".stripMargin
|
||||
}
|
||||
|
||||
object IntAverage extends AggregateFunction[(Int, Int), Int] {
|
||||
override def name(): String = "iavg"
|
||||
override def inputTypes(): Array[DataType] = Array(IntegerType)
|
||||
override def resultType(): DataType = IntegerType
|
||||
|
||||
override def newAggregationState(): (Int, Int) = (0, 0)
|
||||
|
||||
override def update(state: (Int, Int), input: InternalRow): (Int, Int) = {
|
||||
if (input.isNullAt(0)) {
|
||||
state
|
||||
} else {
|
||||
val i = input.getInt(0)
|
||||
state match {
|
||||
case (_, 0) =>
|
||||
(i, 1)
|
||||
case (total, count) =>
|
||||
(total + i, count + 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = {
|
||||
(leftState._1 + rightState._1, leftState._2 + rightState._2)
|
||||
}
|
||||
|
||||
override def produceResult(state: (Int, Int)): Int = state._1 / state._2
|
||||
}
|
||||
|
||||
object LongAverage extends AggregateFunction[(Long, Long), Long] {
|
||||
override def name(): String = "iavg"
|
||||
override def inputTypes(): Array[DataType] = Array(LongType)
|
||||
override def resultType(): DataType = LongType
|
||||
|
||||
override def newAggregationState(): (Long, Long) = (0L, 0L)
|
||||
|
||||
override def update(state: (Long, Long), input: InternalRow): (Long, Long) = {
|
||||
if (input.isNullAt(0)) {
|
||||
state
|
||||
} else {
|
||||
val l = input.getLong(0)
|
||||
state match {
|
||||
case (_, 0L) =>
|
||||
(l, 1)
|
||||
case (total, count) =>
|
||||
(total + l, count + 1L)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = {
|
||||
(leftState._1 + rightState._1, leftState._2 + rightState._2)
|
||||
}
|
||||
|
||||
override def produceResult(state: (Long, Long)): Long = state._1 / state._2
|
||||
}
|
||||
|
||||
object NoImplAverage extends UnboundFunction {
|
||||
override def name(): String = "no_impl_avg"
|
||||
override def description(): String = name()
|
||||
|
||||
override def bind(inputType: StructType): BoundFunction = {
|
||||
throw new UnsupportedOperationException(s"Not implemented")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.connector
|
|||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.sql.QueryTest
|
||||
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryPartitionTableCatalog, InMemoryTableCatalog, StagingInMemoryTableCatalog}
|
||||
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryCatalog, InMemoryPartitionTableCatalog, StagingInMemoryTableCatalog}
|
||||
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
|
||||
import org.apache.spark.sql.test.SharedSparkSession
|
||||
|
||||
|
@ -32,11 +32,11 @@ trait DatasourceV2SQLBase
|
|||
}
|
||||
|
||||
before {
|
||||
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
|
||||
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
|
||||
spark.conf.set("spark.sql.catalog.testpart", classOf[InMemoryPartitionTableCatalog].getName)
|
||||
spark.conf.set(
|
||||
"spark.sql.catalog.testcat_atomic", classOf[StagingInMemoryTableCatalog].getName)
|
||||
spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName)
|
||||
spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryCatalog].getName)
|
||||
spark.conf.set(
|
||||
V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName)
|
||||
|
||||
|
|
Loading…
Reference in a new issue