[SPARK-27658][SQL] Add FunctionCatalog API
## What changes were proposed in this pull request? This adds a new API for catalog plugins that exposes functions to Spark. The API can list and load functions. This does not include create, delete, or alter operations. - [Design Document](https://docs.google.com/document/d/1PLBieHIlxZjmoUB0ERF-VozCRJ0xw2j3qKvUNWpWA2U/edit?usp=sharing) There are 3 types of functions defined: * A `ScalarFunction` that produces a value for every call * An `AggregateFunction` that produces a value after updates for a group of rows Functions loaded from the catalog by name as `UnboundFunction`. Once input arguments are determined `bind` is called on the unbound function to get a `BoundFunction` implementation that is one of the 3 types above. Binding can fail if the function doesn't support the input type. `BoundFunction` returns the result type produced by the function. ## How was this patch tested? This includes a test that demonstrates the new API. Closes #24559 from rdblue/SPARK-27658-add-function-catalog-api. Authored-by: Ryan Blue <blue@apache.org> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
06c09a79b3
commit
3c7d6c38e8
|
@ -0,0 +1,49 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.connector.catalog;
|
||||||
|
|
||||||
|
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
|
||||||
|
import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
|
||||||
|
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Catalog methods for working with Functions.
|
||||||
|
*/
|
||||||
|
public interface FunctionCatalog extends CatalogPlugin {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List the functions in a namespace from the catalog.
|
||||||
|
* <p>
|
||||||
|
* If there are no functions in the namespace, implementations should return an empty array.
|
||||||
|
*
|
||||||
|
* @param namespace a multi-part namespace
|
||||||
|
* @return an array of Identifiers for functions
|
||||||
|
* @throws NoSuchNamespaceException If the namespace does not exist (optional).
|
||||||
|
*/
|
||||||
|
Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load a function by {@link Identifier identifier} from the catalog.
|
||||||
|
*
|
||||||
|
* @param ident a function identifier
|
||||||
|
* @return an unbound function instance
|
||||||
|
* @throws NoSuchFunctionException If the function doesn't exist
|
||||||
|
*/
|
||||||
|
UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException;
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,94 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.connector.catalog.functions;
|
||||||
|
|
||||||
|
import org.apache.spark.sql.catalyst.InternalRow;
|
||||||
|
import org.apache.spark.sql.types.DataType;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface for a function that produces a result value by aggregating over multiple input rows.
|
||||||
|
* <p>
|
||||||
|
* For each input row, Spark will call an update method that corresponds to the
|
||||||
|
* {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
|
||||||
|
* Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
|
||||||
|
* update with {@link InternalRow}.
|
||||||
|
* <p>
|
||||||
|
* The JVM type of result values produced by this function must be the type used by Spark's
|
||||||
|
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
|
||||||
|
* <p>
|
||||||
|
* All implementations must support partial aggregation by implementing merge so that Spark can
|
||||||
|
* partially aggregate and shuffle intermediate results, instead of shuffling all rows for an
|
||||||
|
* aggregate. This reduces the impact of data skew and the amount of data shuffled to produce the
|
||||||
|
* result.
|
||||||
|
* <p>
|
||||||
|
* Intermediate aggregation state must be {@link Serializable} so that state produced by parallel
|
||||||
|
* tasks can be serialized, shuffled, and then merged to produce a final result.
|
||||||
|
*
|
||||||
|
* @param <S> the JVM type for the aggregation's intermediate state; must be {@link Serializable}
|
||||||
|
* @param <R> the JVM type of result values
|
||||||
|
*/
|
||||||
|
public interface AggregateFunction<S extends Serializable, R> extends BoundFunction {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize state for an aggregation.
|
||||||
|
* <p>
|
||||||
|
* This method is called one or more times for every group of values to initialize intermediate
|
||||||
|
* aggregation state. More than one intermediate aggregation state variable may be used when the
|
||||||
|
* aggregation is run in parallel tasks.
|
||||||
|
* <p>
|
||||||
|
* Implementations that return null must support null state passed into all other methods.
|
||||||
|
*
|
||||||
|
* @return a state instance or null
|
||||||
|
*/
|
||||||
|
S newAggregationState();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the aggregation state with a new row.
|
||||||
|
* <p>
|
||||||
|
* This is called for each row in a group to update an intermediate aggregation state.
|
||||||
|
*
|
||||||
|
* @param state intermediate aggregation state
|
||||||
|
* @param input an input row
|
||||||
|
* @return updated aggregation state
|
||||||
|
*/
|
||||||
|
default S update(S state, InternalRow input) {
|
||||||
|
throw new UnsupportedOperationException("Cannot find a compatible AggregateFunction#update");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Merge two partial aggregation states.
|
||||||
|
* <p>
|
||||||
|
* This is called to merge intermediate aggregation states that were produced by parallel tasks.
|
||||||
|
*
|
||||||
|
* @param leftState intermediate aggregation state
|
||||||
|
* @param rightState intermediate aggregation state
|
||||||
|
* @return combined aggregation state
|
||||||
|
*/
|
||||||
|
S merge(S leftState, S rightState);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce the aggregation result based on intermediate state.
|
||||||
|
*
|
||||||
|
* @param state intermediate aggregation state
|
||||||
|
* @return a result value
|
||||||
|
*/
|
||||||
|
R produceResult(S state);
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,99 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.connector.catalog.functions;
|
||||||
|
|
||||||
|
import org.apache.spark.sql.types.DataType;
|
||||||
|
import org.apache.spark.sql.types.IntegerType;
|
||||||
|
import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a function that is bound to an input type.
|
||||||
|
*/
|
||||||
|
public interface BoundFunction extends Function {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the required {@link DataType data types} of the input values to this function.
|
||||||
|
* <p>
|
||||||
|
* If the types returned differ from the types passed to {@link UnboundFunction#bind(StructType)},
|
||||||
|
* Spark will cast input values to the required data types. This allows implementations to
|
||||||
|
* delegate input value casting to Spark.
|
||||||
|
*
|
||||||
|
* @return an array of input value data types
|
||||||
|
*/
|
||||||
|
DataType[] inputTypes();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link DataType data type} of values produced by this function.
|
||||||
|
* <p>
|
||||||
|
* For example, a "plus" function may return {@link IntegerType} when it is bound to arguments
|
||||||
|
* that are also {@link IntegerType}.
|
||||||
|
*
|
||||||
|
* @return a data type for values produced by this function
|
||||||
|
*/
|
||||||
|
DataType resultType();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns whether the values produced by this function may be null.
|
||||||
|
* <p>
|
||||||
|
* For example, a "plus" function may return false when it is bound to arguments that are always
|
||||||
|
* non-null, but true when either argument may be null.
|
||||||
|
*
|
||||||
|
* @return true if values produced by this function may be null, false otherwise
|
||||||
|
*/
|
||||||
|
default boolean isResultNullable() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns whether this function result is deterministic.
|
||||||
|
* <p>
|
||||||
|
* By default, functions are assumed to be deterministic. Functions that are not deterministic
|
||||||
|
* should override this method so that Spark can ensure the function runs only once for a given
|
||||||
|
* input.
|
||||||
|
*
|
||||||
|
* @return true if this function is deterministic, false otherwise
|
||||||
|
*/
|
||||||
|
default boolean isDeterministic() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the canonical name of this function, used to determine if functions are equivalent.
|
||||||
|
* <p>
|
||||||
|
* The canonical name is used to determine whether two functions are the same when loaded by
|
||||||
|
* different catalogs. For example, the same catalog implementation may be used for by two
|
||||||
|
* environments, "prod" and "test". Functions produced by the catalogs may be equivalent, but
|
||||||
|
* loaded using different names, like "test.func_name" and "prod.func_name".
|
||||||
|
* <p>
|
||||||
|
* Names returned by this function should be unique and unlikely to conflict with similar
|
||||||
|
* functions in other catalogs. For example, many catalogs may define a "bucket" function with a
|
||||||
|
* different implementation. Adding context, like "com.mycompany.bucket(string)", is recommended
|
||||||
|
* to avoid unintentional collisions.
|
||||||
|
*
|
||||||
|
* @return a canonical name for this function
|
||||||
|
*/
|
||||||
|
default String canonicalName() {
|
||||||
|
// by default, use a random UUID so a function is never equivalent to another, even itself.
|
||||||
|
// this method is not required so that generated implementations (or careless ones) are not
|
||||||
|
// added and forgotten. for example, returning "" as a place-holder could cause unnecessary
|
||||||
|
// bugs if not replaced before release.
|
||||||
|
return UUID.randomUUID().toString();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.connector.catalog.functions;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base class for user-defined functions.
|
||||||
|
*/
|
||||||
|
public interface Function extends Serializable {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A name to identify this function. Implementations should provide a meaningful name, like the
|
||||||
|
* database and function name from the catalog.
|
||||||
|
*/
|
||||||
|
String name();
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.connector.catalog.functions;
|
||||||
|
|
||||||
|
import org.apache.spark.sql.catalyst.InternalRow;
|
||||||
|
import org.apache.spark.sql.types.DataType;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface for a function that produces a result value for each input row.
|
||||||
|
* <p>
|
||||||
|
* For each input row, Spark will call a produceResult method that corresponds to the
|
||||||
|
* {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
|
||||||
|
* Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
|
||||||
|
* {@link #produceResult(InternalRow)}.
|
||||||
|
* <p>
|
||||||
|
* The JVM type of result values produced by this function must be the type used by Spark's
|
||||||
|
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
|
||||||
|
*
|
||||||
|
* @param <R> the JVM type of result values
|
||||||
|
*/
|
||||||
|
public interface ScalarFunction<R> extends BoundFunction {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies the function to an input row to produce a value.
|
||||||
|
*
|
||||||
|
* @param input an input row
|
||||||
|
* @return a result value
|
||||||
|
*/
|
||||||
|
default R produceResult(InternalRow input) {
|
||||||
|
throw new UnsupportedOperationException(
|
||||||
|
"Cannot find a compatible ScalarFunction#produceResult");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,50 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.connector.catalog.functions;
|
||||||
|
|
||||||
|
import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a user-defined function that is not bound to input types.
|
||||||
|
*/
|
||||||
|
public interface UnboundFunction extends Function {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bind this function to an input type.
|
||||||
|
* <p>
|
||||||
|
* If the input type is not supported, implementations must throw
|
||||||
|
* {@link UnsupportedOperationException}.
|
||||||
|
* <p>
|
||||||
|
* For example, a "length" function that only supports a single string argument should throw
|
||||||
|
* UnsupportedOperationException if the struct has more than one field or if that field is not a
|
||||||
|
* string, and it may optionally throw if the field is nullable.
|
||||||
|
*
|
||||||
|
* @param inputType a struct type for inputs that will be passed to the bound function
|
||||||
|
* @return a function that can process rows with the given input type
|
||||||
|
* @throws UnsupportedOperationException If the function cannot be applied to the input type
|
||||||
|
*/
|
||||||
|
BoundFunction bind(StructType inputType);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns Function documentation.
|
||||||
|
*
|
||||||
|
* @return this function's documentation
|
||||||
|
*/
|
||||||
|
String description();
|
||||||
|
|
||||||
|
}
|
|
@ -65,10 +65,20 @@ class NoSuchPartitionException(message: String) extends AnalysisException(messag
|
||||||
class NoSuchPermanentFunctionException(db: String, func: String)
|
class NoSuchPermanentFunctionException(db: String, func: String)
|
||||||
extends AnalysisException(s"Function '$func' not found in database '$db'")
|
extends AnalysisException(s"Function '$func' not found in database '$db'")
|
||||||
|
|
||||||
class NoSuchFunctionException(db: String, func: String, cause: Option[Throwable] = None)
|
class NoSuchFunctionException(
|
||||||
extends AnalysisException(
|
msg: String,
|
||||||
s"Undefined function: '$func'. This function is neither a registered temporary function nor " +
|
cause: Option[Throwable]) extends AnalysisException(msg, cause = cause) {
|
||||||
|
|
||||||
|
def this(db: String, func: String, cause: Option[Throwable] = None) = {
|
||||||
|
this(s"Undefined function: '$func'. " +
|
||||||
|
s"This function is neither a registered temporary function nor " +
|
||||||
s"a permanent function registered in the database '$db'.", cause = cause)
|
s"a permanent function registered in the database '$db'.", cause = cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
def this(identifier: Identifier) = {
|
||||||
|
this(s"Undefined function: ${identifier.quoted}", cause = None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class NoSuchPartitionsException(message: String) extends AnalysisException(message) {
|
class NoSuchPartitionsException(message: String) extends AnalysisException(message) {
|
||||||
def this(db: String, table: String, specs: Seq[TablePartitionSpec]) = {
|
def this(db: String, table: String, specs: Seq[TablePartitionSpec]) = {
|
||||||
|
|
|
@ -0,0 +1,148 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.connector.catalog.functions
|
||||||
|
|
||||||
|
import org.apache.spark.SparkFunSuite
|
||||||
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
|
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType}
|
||||||
|
|
||||||
|
class AggregateFunctionSuite extends SparkFunSuite {
|
||||||
|
test("Test simple iavg(int)") {
|
||||||
|
val rows = Seq(InternalRow(2), InternalRow(2), InternalRow(2))
|
||||||
|
|
||||||
|
val bound = IntegralAverage.bind(new StructType().add("foo", IntegerType, nullable = false))
|
||||||
|
assert(bound.isInstanceOf[AggregateFunction[_, _]])
|
||||||
|
val udaf = bound.asInstanceOf[AggregateFunction[Serializable, _]]
|
||||||
|
|
||||||
|
val finalState = rows.foldLeft(udaf.newAggregationState()) { (state, row) =>
|
||||||
|
udaf.update(state, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(udaf.produceResult(finalState) == 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("Test simple iavg(long)") {
|
||||||
|
val bigValue = 9762097370951020L
|
||||||
|
val rows = Seq(InternalRow(bigValue + 2), InternalRow(bigValue), InternalRow(bigValue - 2))
|
||||||
|
|
||||||
|
val bound = IntegralAverage.bind(new StructType().add("foo", LongType, nullable = false))
|
||||||
|
assert(bound.isInstanceOf[AggregateFunction[_, _]])
|
||||||
|
val udaf = bound.asInstanceOf[AggregateFunction[Serializable, _]]
|
||||||
|
|
||||||
|
val finalState = rows.foldLeft(udaf.newAggregationState()) { (state, row) =>
|
||||||
|
udaf.update(state, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(udaf.produceResult(finalState) == bigValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("Test associative iavg(long)") {
|
||||||
|
val bigValue = 7620099737951020L
|
||||||
|
val rows = Seq(InternalRow(bigValue + 2), InternalRow(bigValue), InternalRow(bigValue - 2))
|
||||||
|
|
||||||
|
val bound = IntegralAverage.bind(new StructType().add("foo", LongType, nullable = false))
|
||||||
|
assert(bound.isInstanceOf[AggregateFunction[_, _]])
|
||||||
|
val udaf = bound.asInstanceOf[AggregateFunction[Serializable, _]]
|
||||||
|
|
||||||
|
val state1 = rows.foldLeft(udaf.newAggregationState()) { (state, row) =>
|
||||||
|
udaf.update(state, row)
|
||||||
|
}
|
||||||
|
val state2 = rows.foldLeft(udaf.newAggregationState()) { (state, row) =>
|
||||||
|
udaf.update(state, row)
|
||||||
|
}
|
||||||
|
val finalState = udaf.merge(state1, state2)
|
||||||
|
|
||||||
|
assert(udaf.produceResult(finalState) == bigValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
object IntegralAverage extends UnboundFunction {
|
||||||
|
override def name(): String = "iavg"
|
||||||
|
|
||||||
|
override def bind(inputType: StructType): BoundFunction = {
|
||||||
|
if (inputType.fields.length > 1) {
|
||||||
|
throw new UnsupportedOperationException("Too many arguments")
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inputType.fields(0).nullable) {
|
||||||
|
throw new UnsupportedOperationException("Nullable values are not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
inputType.fields(0).dataType match {
|
||||||
|
case _: IntegerType => IntAverage
|
||||||
|
case _: LongType => LongAverage
|
||||||
|
case dataType =>
|
||||||
|
throw new UnsupportedOperationException(s"Unsupported non-integral type: $dataType")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override def description(): String =
|
||||||
|
"""iavg: produces an average using integer division
|
||||||
|
| iavg(int not null) -> int
|
||||||
|
| iavg(bigint not null) -> bigint""".stripMargin
|
||||||
|
}
|
||||||
|
|
||||||
|
object IntAverage extends AggregateFunction[(Int, Int), Int] {
|
||||||
|
|
||||||
|
override def inputTypes(): Array[DataType] = Array(IntegerType)
|
||||||
|
|
||||||
|
override def name(): String = "iavg"
|
||||||
|
|
||||||
|
override def newAggregationState(): (Int, Int) = (0, 0)
|
||||||
|
|
||||||
|
override def update(state: (Int, Int), input: InternalRow): (Int, Int) = {
|
||||||
|
val i = input.getInt(0)
|
||||||
|
val (total, count) = state
|
||||||
|
(total + i, count + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = {
|
||||||
|
(leftState._1 + rightState._1, leftState._2 + rightState._2)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def produceResult(state: (Int, Int)): Int = state._1 / state._2
|
||||||
|
|
||||||
|
override def resultType(): DataType = IntegerType
|
||||||
|
}
|
||||||
|
|
||||||
|
object LongAverage extends AggregateFunction[(Long, Long), Long] {
|
||||||
|
|
||||||
|
override def inputTypes(): Array[DataType] = Array(LongType)
|
||||||
|
|
||||||
|
override def name(): String = "iavg"
|
||||||
|
|
||||||
|
override def newAggregationState(): (Long, Long) = (0L, 0L)
|
||||||
|
|
||||||
|
override def update(state: (Long, Long), input: InternalRow): (Long, Long) = {
|
||||||
|
val l = input.getLong(0)
|
||||||
|
state match {
|
||||||
|
case (_, 0L) =>
|
||||||
|
(l, 1)
|
||||||
|
case (total, count) =>
|
||||||
|
(total + l, count + 1L)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = {
|
||||||
|
(leftState._1 + rightState._1, leftState._2 + rightState._2)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def produceResult(state: (Long, Long)): Long = state._1 / state._2
|
||||||
|
|
||||||
|
override def resultType(): DataType = IntegerType
|
||||||
|
}
|
Loading…
Reference in a new issue