diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java
new file mode 100644
index 0000000000..651c9148c4
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog;
+
+import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+
+/**
+ * Catalog methods for working with Functions.
+ */
+public interface FunctionCatalog extends CatalogPlugin {
+
+ /**
+ * List the functions in a namespace from the catalog.
+ *
+ * If there are no functions in the namespace, implementations should return an empty array.
+ *
+ * @param namespace a multi-part namespace
+ * @return an array of Identifiers for functions
+ * @throws NoSuchNamespaceException If the namespace does not exist (optional).
+ */
+ Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceException;
+
+ /**
+ * Load a function by {@link Identifier identifier} from the catalog.
+ *
+ * @param ident a function identifier
+ * @return an unbound function instance
+ * @throws NoSuchFunctionException If the function doesn't exist
+ */
+ UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException;
+
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java
new file mode 100644
index 0000000000..6982ebb329
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.DataType;
+
+import java.io.Serializable;
+
+/**
+ * Interface for a function that produces a result value by aggregating over multiple input rows.
+ *
+ * For each input row, Spark will call an update method that corresponds to the
+ * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
+ * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
+ * update with {@link InternalRow}.
+ *
+ * The JVM type of result values produced by this function must be the type used by Spark's
+ * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
+ *
+ * All implementations must support partial aggregation by implementing merge so that Spark can
+ * partially aggregate and shuffle intermediate results, instead of shuffling all rows for an
+ * aggregate. This reduces the impact of data skew and the amount of data shuffled to produce the
+ * result.
+ *
+ * Intermediate aggregation state must be {@link Serializable} so that state produced by parallel
+ * tasks can be serialized, shuffled, and then merged to produce a final result.
+ *
+ * @param <S> the JVM type for the aggregation's intermediate state; must be {@link Serializable}
+ * @param <R> the JVM type of result values
+ */
+public interface AggregateFunction<S extends Serializable, R> extends BoundFunction {
+
+ /**
+ * Initialize state for an aggregation.
+ *
+ * This method is called one or more times for every group of values to initialize intermediate
+ * aggregation state. More than one intermediate aggregation state variable may be used when the
+ * aggregation is run in parallel tasks.
+ *
+ * Implementations that return null must support null state passed into all other methods.
+ *
+ * @return a state instance or null
+ */
+ S newAggregationState();
+
+ /**
+ * Update the aggregation state with a new row.
+ *
+ * This is called for each row in a group to update an intermediate aggregation state.
+ *
+ * @param state intermediate aggregation state
+ * @param input an input row
+ * @return updated aggregation state
+ */
+ default S update(S state, InternalRow input) {
+ throw new UnsupportedOperationException("Cannot find a compatible AggregateFunction#update");
+ }
+
+ /**
+ * Merge two partial aggregation states.
+ *
+ * This is called to merge intermediate aggregation states that were produced by parallel tasks.
+ *
+ * @param leftState intermediate aggregation state
+ * @param rightState intermediate aggregation state
+ * @return combined aggregation state
+ */
+ S merge(S leftState, S rightState);
+
+ /**
+ * Produce the aggregation result based on intermediate state.
+ *
+ * @param state intermediate aggregation state
+ * @return a result value
+ */
+ R produceResult(S state);
+
+}
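To make the partial-aggregation contract concrete, here is a minimal sketch of an implementation. The `LongSum` class is hypothetical and not part of this patch; null handling is omitted for brevity, and the boxed `Long` state satisfies the `Serializable` bound:

```java
// Hypothetical sketch: a "sum" aggregate over BIGINT columns written against
// the AggregateFunction contract above.
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.AggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

public class LongSum implements AggregateFunction<Long, Long> {

  @Override
  public String name() {
    return "longsum";
  }

  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.LongType };
  }

  @Override
  public DataType resultType() {
    return DataTypes.LongType;
  }

  @Override
  public Long newAggregationState() {
    return 0L; // fresh per-task state; Spark may create several in parallel
  }

  @Override
  public Long update(Long state, InternalRow input) {
    return state + input.getLong(0); // called once per input row in a group
  }

  @Override
  public Long merge(Long leftState, Long rightState) {
    return leftState + rightState; // combines partial results after the shuffle
  }

  @Override
  public Long produceResult(Long state) {
    return state;
  }
}
```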
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java
new file mode 100644
index 0000000000..c53f94a168
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.UUID;
+
+/**
+ * Represents a function that is bound to an input type.
+ */
+public interface BoundFunction extends Function {
+
+ /**
+ * Returns the required {@link DataType data types} of the input values to this function.
+ *
+ * If the types returned differ from the types passed to {@link UnboundFunction#bind(StructType)},
+ * Spark will cast input values to the required data types. This allows implementations to
+ * delegate input value casting to Spark.
+ *
+ * @return an array of input value data types
+ */
+ DataType[] inputTypes();
+
+ /**
+ * Returns the {@link DataType data type} of values produced by this function.
+ *
+ * For example, a "plus" function may return {@link IntegerType} when it is bound to arguments
+ * that are also {@link IntegerType}.
+ *
+ * @return a data type for values produced by this function
+ */
+ DataType resultType();
+
+ /**
+ * Returns whether the values produced by this function may be null.
+ *
+ * For example, a "plus" function may return false when it is bound to arguments that are always
+ * non-null, but true when either argument may be null.
+ *
+ * @return true if values produced by this function may be null, false otherwise
+ */
+ default boolean isResultNullable() {
+ return true;
+ }
+
+ /**
+ * Returns whether this function's result is deterministic.
+ *
+ * By default, functions are assumed to be deterministic. Functions that are not deterministic
+ * should override this method so that Spark can ensure the function runs only once for a given
+ * input.
+ *
+ * @return true if this function is deterministic, false otherwise
+ */
+ default boolean isDeterministic() {
+ return true;
+ }
+
+ /**
+ * Returns the canonical name of this function, used to determine if functions are equivalent.
+ *
+ * The canonical name is used to determine whether two functions are the same when loaded by
+ * different catalogs. For example, the same catalog implementation may be used by two
+ * environments, "prod" and "test". Functions produced by the catalogs may be equivalent, but
+ * loaded using different names, like "test.func_name" and "prod.func_name".
+ *
+ * Names returned by this function should be unique and unlikely to conflict with similar
+ * functions in other catalogs. For example, many catalogs may define a "bucket" function with a
+ * different implementation. Adding context, like "com.mycompany.bucket(string)", is recommended
+ * to avoid unintentional collisions.
+ *
+ * @return a canonical name for this function
+ */
+ default String canonicalName() {
+ // by default, use a random UUID so a function is never equivalent to another, even itself.
+ // this method is defaulted rather than made abstract so that generated implementations (or
+ // careless ones) are not added and forgotten. for example, returning "" as a place-holder
+ // could cause unnecessary bugs if not replaced before release.
+ return UUID.randomUUID().toString();
+ }
+}
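Because `canonicalName()` defaults to a random UUID, two loads of the same function are never considered equivalent unless the implementation opts in. A hedged sketch of what opting in might look like, using a hypothetical `BucketByString` class (it implements the `ScalarFunction` interface added later in this diff, and the naming scheme follows the javadoc's guidance):

```java
// Hypothetical sketch, not part of this patch: a bound "bucket" function that
// opts into cross-catalog equality by overriding canonicalName() with a
// namespaced, argument-typed name instead of the random-UUID default.
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

public class BucketByString implements ScalarFunction<Integer> {
  private final int numBuckets;

  public BucketByString(int numBuckets) {
    this.numBuckets = numBuckets;
  }

  @Override
  public String name() {
    return "bucket";
  }

  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.StringType };
  }

  @Override
  public DataType resultType() {
    return DataTypes.IntegerType;
  }

  @Override
  public boolean isResultNullable() {
    return false; // a bucket id is produced for every (non-null) input
  }

  @Override
  public String canonicalName() {
    // context plus argument types, per the guidance in the javadoc above
    return "com.example.bucket(" + numBuckets + ", string)";
  }

  @Override
  public Integer produceResult(InternalRow input) {
    // InternalRow's JVM representation of a StringType value is UTF8String
    return Math.floorMod(input.getUTF8String(0).hashCode(), numBuckets);
  }
}
```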
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Function.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Function.java
new file mode 100644
index 0000000000..b7f14eb271
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Function.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.functions;
+
+import java.io.Serializable;
+
+/**
+ * Base class for user-defined functions.
+ */
+public interface Function extends Serializable {
+
+ /**
+ * A name to identify this function. Implementations should provide a meaningful name, like the
+ * database and function name from the catalog.
+ */
+ String name();
+
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
new file mode 100644
index 0000000000..c2106a21c4
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.DataType;
+
+/**
+ * Interface for a function that produces a result value for each input row.
+ *
+ * For each input row, Spark will call a produceResult method that corresponds to the
+ * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
+ * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
+ * {@link #produceResult(InternalRow)}.
+ *
+ * The JVM type of result values produced by this function must be the type used by Spark's
+ * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
+ *
+ * @param <R> the JVM type of result values
+ */
+public interface ScalarFunction<R> extends BoundFunction {
+
+ /**
+ * Applies the function to an input row to produce a value.
+ *
+ * @param input an input row
+ * @return a result value
+ */
+ default R produceResult(InternalRow input) {
+ throw new UnsupportedOperationException(
+ "Cannot find a compatible ScalarFunction#produceResult");
+ }
+
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/UnboundFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/UnboundFunction.java
new file mode 100644
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/UnboundFunction.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Represents a user-defined function that is not bound to input types.
+ */
+public interface UnboundFunction extends Function {
+
+ /**
+ * Bind this function to an input type.
+ *
+ * If the input type is not supported, implementations must throw
+ * {@link UnsupportedOperationException}.
+ *
+ * For example, a "length" function that only supports a single string argument should throw
+ * UnsupportedOperationException if the struct has more than one field or if that field is not a
+ * string, and it may optionally throw if the field is nullable.
+ *
+ * @param inputType a struct type for inputs that will be passed to the bound function
+ * @return a function that can process rows with the given input type
+ * @throws UnsupportedOperationException If the function cannot be applied to the input type
+ */
+ BoundFunction bind(StructType inputType);
+
+ /**
+ * Returns Function documentation.
+ *
+ * @return this function's documentation
+ */
+ String description();
+
+}
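The "length" function described in the `bind` javadoc maps directly onto this interface. A minimal sketch, assuming hypothetical `UnboundLength` and `BoundLength` classes that are not part of this patch:

```java
// Hypothetical sketch: an unbound "length" function that validates the input
// struct at bind time, then returns a bound scalar function over InternalRow.
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructType;

public class UnboundLength implements UnboundFunction {

  @Override
  public String name() {
    return "length";
  }

  @Override
  public BoundFunction bind(StructType inputType) {
    // reject anything but a single string argument, per the bind contract
    if (inputType.fields().length != 1) {
      throw new UnsupportedOperationException("length takes exactly one argument");
    }
    if (!(inputType.fields()[0].dataType() instanceof StringType)) {
      throw new UnsupportedOperationException("length requires a string argument");
    }
    return new BoundLength();
  }

  @Override
  public String description() {
    return "length(string) -> int: returns the number of characters in its argument";
  }

  private static class BoundLength implements ScalarFunction<Integer> {

    @Override
    public String name() {
      return "length";
    }

    @Override
    public DataType[] inputTypes() {
      return new DataType[] { DataTypes.StringType };
    }

    @Override
    public DataType resultType() {
      return DataTypes.IntegerType;
    }

    @Override
    public Integer produceResult(InternalRow input) {
      // InternalRow's JVM representation of StringType is UTF8String
      return input.getUTF8String(0).numChars();
    }
  }
}
```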
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
index 8a1913b40b..ba5a9c618c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
@@ -65,10 +65,20 @@ class NoSuchPartitionException(message: String) extends AnalysisException(messag
class NoSuchPermanentFunctionException(db: String, func: String)
extends AnalysisException(s"Function '$func' not found in database '$db'")
-class NoSuchFunctionException(db: String, func: String, cause: Option[Throwable] = None)
- extends AnalysisException(
- s"Undefined function: '$func'. This function is neither a registered temporary function nor " +
- s"a permanent function registered in the database '$db'.", cause = cause)
+class NoSuchFunctionException(
+ msg: String,
+ cause: Option[Throwable]) extends AnalysisException(msg, cause = cause) {
+
+ def this(db: String, func: String, cause: Option[Throwable] = None) = {
+ this(s"Undefined function: '$func'. " +
+ s"This function is neither a registered temporary function nor " +
+ s"a permanent function registered in the database '$db'.", cause = cause)
+ }
+
+ def this(identifier: Identifier) = {
+ this(s"Undefined function: ${identifier.quoted}", cause = None)
+ }
+}
class NoSuchPartitionsException(message: String) extends AnalysisException(message) {
def this(db: String, table: String, specs: Seq[TablePartitionSpec]) = {
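For illustration, a hedged sketch of the two constructors that v1- and v2-oriented code paths can now use; the `NoSuchFunctionExamples` class below is hypothetical, and the Java-side `Option.empty()` stands in for the Scala default argument:

```java
// Hypothetical usage sketch of both NoSuchFunctionException constructors.
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
import org.apache.spark.sql.connector.catalog.Identifier;
import scala.Option;

public class NoSuchFunctionExamples {
  public static void main(String[] args) {
    // v1-style: db + function name (message mentions temporary/permanent functions)
    NoSuchFunctionException legacy =
        new NoSuchFunctionException("default", "iavg", Option.empty());
    // v2-style: a catalog Identifier; message is roughly "Undefined function: ns.iavg"
    NoSuchFunctionException v2 =
        new NoSuchFunctionException(Identifier.of(new String[] {"ns"}, "iavg"));
    System.out.println(legacy.getMessage());
    System.out.println(v2.getMessage());
  }
}
```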
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala
new file mode 100644
index 0000000000..5e5a7d4f05
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.functions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType}
+
+class AggregateFunctionSuite extends SparkFunSuite {
+ test("Test simple iavg(int)") {
+ val rows = Seq(InternalRow(2), InternalRow(2), InternalRow(2))
+
+ val bound = IntegralAverage.bind(new StructType().add("foo", IntegerType, nullable = false))
+ assert(bound.isInstanceOf[AggregateFunction[_, _]])
+ val udaf = bound.asInstanceOf[AggregateFunction[Serializable, _]]
+
+ val finalState = rows.foldLeft(udaf.newAggregationState()) { (state, row) =>
+ udaf.update(state, row)
+ }
+
+ assert(udaf.produceResult(finalState) == 2)
+ }
+
+ test("Test simple iavg(long)") {
+ val bigValue = 9762097370951020L
+ val rows = Seq(InternalRow(bigValue + 2), InternalRow(bigValue), InternalRow(bigValue - 2))
+
+ val bound = IntegralAverage.bind(new StructType().add("foo", LongType, nullable = false))
+ assert(bound.isInstanceOf[AggregateFunction[_, _]])
+ val udaf = bound.asInstanceOf[AggregateFunction[Serializable, _]]
+
+ val finalState = rows.foldLeft(udaf.newAggregationState()) { (state, row) =>
+ udaf.update(state, row)
+ }
+
+ assert(udaf.produceResult(finalState) == bigValue)
+ }
+
+ test("Test associative iavg(long)") {
+ val bigValue = 7620099737951020L
+ val rows = Seq(InternalRow(bigValue + 2), InternalRow(bigValue), InternalRow(bigValue - 2))
+
+ val bound = IntegralAverage.bind(new StructType().add("foo", LongType, nullable = false))
+ assert(bound.isInstanceOf[AggregateFunction[_, _]])
+ val udaf = bound.asInstanceOf[AggregateFunction[Serializable, _]]
+
+ val state1 = rows.foldLeft(udaf.newAggregationState()) { (state, row) =>
+ udaf.update(state, row)
+ }
+ val state2 = rows.foldLeft(udaf.newAggregationState()) { (state, row) =>
+ udaf.update(state, row)
+ }
+ val finalState = udaf.merge(state1, state2)
+
+ assert(udaf.produceResult(finalState) == bigValue)
+ }
+}
+
+object IntegralAverage extends UnboundFunction {
+ override def name(): String = "iavg"
+
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.fields.length != 1) {
+ throw new UnsupportedOperationException("iavg takes exactly one argument")
+ }
+
+ if (inputType.fields(0).nullable) {
+ throw new UnsupportedOperationException("Nullable values are not supported")
+ }
+
+ inputType.fields(0).dataType match {
+ case _: IntegerType => IntAverage
+ case _: LongType => LongAverage
+ case dataType =>
+ throw new UnsupportedOperationException(s"Unsupported non-integral type: $dataType")
+ }
+ }
+
+ override def description(): String =
+ """iavg: produces an average using integer division
+ | iavg(int not null) -> int
+ | iavg(bigint not null) -> bigint""".stripMargin
+}
+
+object IntAverage extends AggregateFunction[(Int, Int), Int] {
+
+ override def inputTypes(): Array[DataType] = Array(IntegerType)
+
+ override def name(): String = "iavg"
+
+ override def newAggregationState(): (Int, Int) = (0, 0)
+
+ override def update(state: (Int, Int), input: InternalRow): (Int, Int) = {
+ val i = input.getInt(0)
+ val (total, count) = state
+ (total + i, count + 1)
+ }
+
+ override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = {
+ (leftState._1 + rightState._1, leftState._2 + rightState._2)
+ }
+
+ override def produceResult(state: (Int, Int)): Int = state._1 / state._2
+
+ override def resultType(): DataType = IntegerType
+}
+
+object LongAverage extends AggregateFunction[(Long, Long), Long] {
+
+ override def inputTypes(): Array[DataType] = Array(LongType)
+
+ override def name(): String = "iavg"
+
+ override def newAggregationState(): (Long, Long) = (0L, 0L)
+
+ override def update(state: (Long, Long), input: InternalRow): (Long, Long) = {
+ val l = input.getLong(0)
+ state match {
+ case (_, 0L) =>
+ (l, 1)
+ case (total, count) =>
+ (total + l, count + 1L)
+ }
+ }
+
+ override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = {
+ (leftState._1 + rightState._1, leftState._2 + rightState._2)
+ }
+
+ override def produceResult(state: (Long, Long)): Long = state._1 / state._2
+
+ override def resultType(): DataType = LongType
+}
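To tie the API together, here is a hedged sketch of a minimal in-memory `FunctionCatalog`; the `InMemoryFunctionCatalog` class is hypothetical and not part of this patch. Like other `CatalogPlugin` implementations, a real version would be registered through a `spark.sql.catalog.<name>` setting:

```java
// Hypothetical sketch: an in-memory FunctionCatalog showing how listFunctions
// and loadFunction serve UnboundFunction instances for later binding.
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
import org.apache.spark.sql.connector.catalog.FunctionCatalog;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

public class InMemoryFunctionCatalog implements FunctionCatalog {
  private final Map<Identifier, UnboundFunction> functions = new ConcurrentHashMap<>();
  private String catalogName = null;

  @Override
  public void initialize(String name, CaseInsensitiveStringMap options) {
    this.catalogName = name;
  }

  @Override
  public String name() {
    return catalogName;
  }

  // test/setup hook, not part of the FunctionCatalog API
  public void addFunction(Identifier ident, UnboundFunction fn) {
    functions.put(ident, fn);
  }

  @Override
  public Identifier[] listFunctions(String[] namespace) {
    // an empty array, not null, when the namespace holds no functions
    return functions.keySet().stream()
        .filter(ident -> Arrays.equals(ident.namespace(), namespace))
        .toArray(Identifier[]::new);
  }

  @Override
  public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException {
    UnboundFunction fn = functions.get(ident);
    if (fn == null) {
      // uses the new Identifier-based constructor added in this patch
      throw new NoSuchFunctionException(ident);
    }
    return fn;
  }
}
```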