diff --git a/examples/src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider b/examples/src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider new file mode 100644 index 0000000000..c239843a3b --- /dev/null +++ b/examples/src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider @@ -0,0 +1 @@ +org.apache.spark.examples.extensions.SessionExtensionsWithLoader diff --git a/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala b/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala new file mode 100644 index 0000000000..d25f220499 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.extensions + +import org.apache.spark.sql.catalyst.expressions.{CurrentDate, Expression, RuntimeReplaceable, SubtractDates} + +/** + * How old are you in days? 
+ */ +case class AgeExample(birthday: Expression, child: Expression) extends RuntimeReplaceable { + + def this(birthday: Expression) = this(birthday, SubtractDates(CurrentDate(), birthday)) + override def exprsReplaced: Seq[Expression] = Seq(birthday) + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild) +} diff --git a/examples/src/main/scala/org/apache/spark/examples/extensions/SessionExtensionsWithLoader.scala b/examples/src/main/scala/org/apache/spark/examples/extensions/SessionExtensionsWithLoader.scala new file mode 100644 index 0000000000..0daf7346bc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/extensions/SessionExtensionsWithLoader.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.extensions + +import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} + +class SessionExtensionsWithLoader extends SparkSessionExtensionsProvider { + override def apply(v1: SparkSessionExtensions): Unit = { + v1.injectFunction( + (new FunctionIdentifier("age_two"), + new ExpressionInfo(classOf[AgeExample].getName, + "age_two"), (children: Seq[Expression]) => new AgeExample(children.head))) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/extensions/SessionExtensionsWithoutLoader.scala b/examples/src/main/scala/org/apache/spark/examples/extensions/SessionExtensionsWithoutLoader.scala new file mode 100644 index 0000000000..5194c43297 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/extensions/SessionExtensionsWithoutLoader.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.extensions + +import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} + +class SessionExtensionsWithoutLoader extends SparkSessionExtensionsProvider { + override def apply(v1: SparkSessionExtensions): Unit = { + v1.injectFunction( + (new FunctionIdentifier("age_one"), + new ExpressionInfo(classOf[AgeExample].getName, + "age_one"), (children: Seq[Expression]) => new AgeExample(children.head))) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/extensions/SparkSessionExtensionsTest.scala b/examples/src/main/scala/org/apache/spark/examples/extensions/SparkSessionExtensionsTest.scala new file mode 100644 index 0000000000..8a906964a4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/extensions/SparkSessionExtensionsTest.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.extensions + +import org.apache.spark.sql.SparkSession + +/** + * [[SessionExtensionsWithLoader]] is registered in + * src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider + * + * [[SessionExtensionsWithoutLoader]] is registered via spark.sql.extensions + */ +object SparkSessionExtensionsTest { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder() + .appName("SparkSessionExtensionsTest") + .config("spark.sql.extensions", classOf[SessionExtensionsWithoutLoader].getName) + .getOrCreate() + spark.sql("SELECT age_one('2018-11-17'), age_two('2018-11-17')").show() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 5768c19470..62852fe941 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import java.io.Closeable -import java.util.UUID +import java.util.{ServiceLoader, UUID} import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} @@ -949,6 +949,7 @@ object SparkSession extends Logging { // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions. 
} + loadExtensions(extensions) applyExtensions( sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty), extensions) @@ -1203,4 +1204,22 @@ object SparkSession extends Logging { } extensions } + + /** + * Load extensions from [[ServiceLoader]] and use them + */ + private def loadExtensions(extensions: SparkSessionExtensions): Unit = { + val loader = ServiceLoader.load(classOf[SparkSessionExtensionsProvider], + Utils.getContextOrSparkClassLoader) + val loadedExts = loader.iterator() + + while (loadedExts.hasNext) { + try { + val ext = loadedExts.next() + ext(extensions) + } catch { + case e: Throwable => logWarning("Failed to load session extension", e) + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index a17832610b..b14dce64f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -71,7 +71,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} * {{{ * SparkSession.builder() * .master("...") - * .config("spark.sql.extensions", "org.example.MyExtensions") + * .config("spark.sql.extensions", "org.example.MyExtensions,org.example.YourExtensions") * .getOrCreate() * * class MyExtensions extends Function1[SparkSessionExtensions, Unit] { @@ -84,6 +84,15 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} * } * } * } + * + * class YourExtensions extends SparkSessionExtensionsProvider { + * override def apply(extensions: SparkSessionExtensions): Unit = { + * extensions.injectResolutionRule { session => + * ... + * } + * extensions.injectFunction(...) 
+ * } + * } * }}} * * Note that none of the injected builders should assume that the [[SparkSession]] is fully diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensionsProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensionsProvider.scala new file mode 100644 index 0000000000..23f4faaa0b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensionsProvider.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.{DeveloperApi, Since, Unstable} + +// scalastyle:off line.size.limit +/** + * :: Unstable :: + * + * Base trait for implementations used by [[SparkSessionExtensions]] + * + * + * For example, now we have an external function named `Age` to register as an extension for SparkSession: + * + * + * {{{ + * package org.apache.spark.examples.extensions + * + * import org.apache.spark.sql.catalyst.expressions.{CurrentDate, Expression, RuntimeReplaceable, SubtractDates} + * + * case class Age(birthday: Expression, child: Expression) extends RuntimeReplaceable { + * + * def this(birthday: Expression) = this(birthday, SubtractDates(CurrentDate(), birthday)) + * override def exprsReplaced: Seq[Expression] = Seq(birthday) + * override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + * } + * }}} + * + * We need to create our extension which inherits [[SparkSessionExtensionsProvider]] + * Example: + * + * {{{ + * package org.apache.spark.examples.extensions + * + * import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider} + * import org.apache.spark.sql.catalyst.FunctionIdentifier + * import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} + * + * class MyExtensions extends SparkSessionExtensionsProvider { + * override def apply(v1: SparkSessionExtensions): Unit = { + * v1.injectFunction( + * (new FunctionIdentifier("age"), + * new ExpressionInfo(classOf[Age].getName, "age"), + * (children: Seq[Expression]) => new Age(children.head))) + * } + * } + * }}} + * + * Then, we can inject `MyExtensions` in three ways: with `SparkSession.Builder.withExtensions`, through the `spark.sql.extensions` configuration, or through the `java.util.ServiceLoader` mechanism. + * + * + * @see [[SparkSessionExtensions]] + * @see [[SparkSession.Builder]] + * @see [[java.util.ServiceLoader]] + * + * @since 3.2.0 + */ +@DeveloperApi +@Unstable +@Since("3.2.0") +trait SparkSessionExtensionsProvider extends Function1[SparkSessionExtensions, Unit] +// scalastyle:on line.size.limit diff --git 
a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider new file mode 100644 index 0000000000..b5b01a09e6 --- /dev/null +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider @@ -0,0 +1 @@ +org.apache.spark.sql.YourExtensions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index d4a6d84ce2..c8768ec2c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -46,8 +46,8 @@ import org.apache.spark.unsafe.types.UTF8String * Test cases for the [[SparkSessionExtensions]]. */ class SparkSessionExtensionSuite extends SparkFunSuite { - type ExtensionsBuilder = SparkSessionExtensions => Unit - private def create(builder: ExtensionsBuilder): Seq[ExtensionsBuilder] = Seq(builder) + private def create( + builder: SparkSessionExtensionsProvider): Seq[SparkSessionExtensionsProvider] = Seq(builder) private def stop(spark: SparkSession): Unit = { spark.stop() @@ -55,7 +55,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite { SparkSession.clearDefaultSession() } - private def withSession(builders: Seq[ExtensionsBuilder])(f: SparkSession => Unit): Unit = { + private def withSession( + builders: Seq[SparkSessionExtensionsProvider])(f: SparkSession => Unit): Unit = { val builder = SparkSession.builder().master("local[1]") builders.foreach(builder.withExtensions) val spark = builder.getOrCreate() @@ -355,6 +356,20 @@ class SparkSessionExtensionSuite extends SparkFunSuite { stop(session) } } + + test("SPARK-35380: Loading extensions from ServiceLoader") { + val builder = SparkSession.builder().master("local[1]") + + Seq(None, 
Some(classOf[YourExtensions].getName)).foreach { ext => + ext.foreach(builder.config(SPARK_SESSION_EXTENSIONS.key, _)) + val session = builder.getOrCreate() + try { + assert(session.sql("select get_fake_app_name()").head().getString(0) === "Fake App Name") + } finally { + stop(session) + } + } + } } case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { @@ -959,3 +974,16 @@ class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) { e.injectFunction(MyExtensions2Duplicate.myFunction) } } + +class YourExtensions extends SparkSessionExtensionsProvider { + val getAppName = (FunctionIdentifier("get_fake_app_name"), + new ExpressionInfo( + "zzz.zzz.zzz", + "", + "get_fake_app_name"), + (_: Seq[Expression]) => Literal("Fake App Name")) + + override def apply(v1: SparkSessionExtensions): Unit = { + v1.injectFunction(getAppName) + } +}