[SPARK-35380][SQL] Loading SparkSessionExtensions from ServiceLoader

### What changes were proposed in this pull request?

In https://github.com/yaooqinn/itachi/issues/8, we had a discussion about the current extension injection for the Spark session. We agreed that the current approach is not convenient for either third-party developers or end-users.

It is much simpler if third-party developers can provide a resource file that lists their default extensions, which Spark then loads ahead of time.
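Concretely, a third-party library would ship a provider implementation together with a one-line service registration file, roughly like the following sketch (`com.example.MyDefaultExtensions` is an illustrative name, not part of this PR):

```scala
package com.example

import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider}

// Picked up automatically once the jar is on the classpath;
// end-users no longer need to set spark.sql.extensions by hand.
class MyDefaultExtensions extends SparkSessionExtensionsProvider {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // inject analyzer/optimizer rules, parsers, functions, etc. here
  }
}
```

plus a resource file `META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider` containing the single line `com.example.MyDefaultExtensions`.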

### Why are the changes needed?

A better user experience: extensions shipped on the classpath are discovered and applied automatically, so end-users no longer have to configure `spark.sql.extensions` by hand.

### Does this PR introduce _any_ user-facing change?

No, this is a developer-facing change.

### How was this patch tested?

New tests: a new case in `SparkSessionExtensionSuite` ("SPARK-35380: Loading extensions from ServiceLoader") plus the runnable `SparkSessionExtensionsTest` example.

Closes #32515 from yaooqinn/SPARK-35380.

Authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Kent Yao <yao@apache.org>
Kent Yao committed 2021-05-13 16:34:13 +08:00
parent dd5464976f
commit 51815430b2
10 changed files with 276 additions and 5 deletions

View file

@@ -0,0 +1 @@
org.apache.spark.examples.extensions.SessionExtensionsWithLoader

View file

@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.extensions

import org.apache.spark.sql.catalyst.expressions.{CurrentDate, Expression, RuntimeReplaceable, SubtractDates}

/**
 * How old are you in days?
 */
case class AgeExample(birthday: Expression, child: Expression) extends RuntimeReplaceable {

  def this(birthday: Expression) = this(birthday, SubtractDates(CurrentDate(), birthday))
  override def exprsReplaced: Seq[Expression] = Seq(birthday)

  // Replace `child` (the runtime replacement expression), not `birthday`:
  // copy(newChild) would positionally overwrite birthday instead.
  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)
}

View file

@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.extensions

import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}

class SessionExtensionsWithLoader extends SparkSessionExtensionsProvider {
  override def apply(v1: SparkSessionExtensions): Unit = {
    v1.injectFunction(
      (new FunctionIdentifier("age_two"),
        new ExpressionInfo(classOf[AgeExample].getName, "age_two"),
        (children: Seq[Expression]) => new AgeExample(children.head)))
  }
}

View file

@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.extensions

import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}

class SessionExtensionsWithoutLoader extends SparkSessionExtensionsProvider {
  override def apply(v1: SparkSessionExtensions): Unit = {
    v1.injectFunction(
      (new FunctionIdentifier("age_one"),
        new ExpressionInfo(classOf[AgeExample].getName, "age_one"),
        (children: Seq[Expression]) => new AgeExample(children.head)))
  }
}

View file

@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.extensions

import org.apache.spark.sql.SparkSession

/**
 * [[SessionExtensionsWithLoader]] is registered in
 * src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider
 *
 * [[SessionExtensionsWithoutLoader]] is registered via spark.sql.extensions
 */
object SparkSessionExtensionsTest {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("SparkSessionExtensionsTest")
      .config("spark.sql.extensions", classOf[SessionExtensionsWithoutLoader].getName)
      .getOrCreate()
    spark.sql("SELECT age_one('2018-11-17'), age_two('2018-11-17')").show()
  }
}

View file

@@ -18,7 +18,7 @@
 package org.apache.spark.sql

 import java.io.Closeable
-import java.util.UUID
+import java.util.{ServiceLoader, UUID}
 import java.util.concurrent.TimeUnit._
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
@@ -949,6 +949,7 @@ object SparkSession extends Logging {
         // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
       }
+      loadExtensions(extensions)
       applyExtensions(
         sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
         extensions)
@@ -1203,4 +1204,22 @@ object SparkSession extends Logging {
     }
     extensions
   }
+
+  /**
+   * Load extensions from [[ServiceLoader]] and use them
+   */
+  private def loadExtensions(extensions: SparkSessionExtensions): Unit = {
+    val loader = ServiceLoader.load(classOf[SparkSessionExtensionsProvider],
+      Utils.getContextOrSparkClassLoader)
+    val loadedExts = loader.iterator()
+
+    while (loadedExts.hasNext) {
+      try {
+        val ext = loadedExts.next()
+        ext(extensions)
+      } catch {
+        case e: Throwable => logWarning("Failed to load session extension", e)
+      }
+    }
+  }
 }
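For context on the mechanism: `ServiceLoader.load` finds every class named in `META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider` resource files visible to the given class loader and instantiates each one via its no-arg constructor. A minimal standalone sketch of that discovery loop (`ListProviders` is illustrative, not part of this PR):

```scala
import java.util.ServiceLoader

import org.apache.spark.sql.SparkSessionExtensionsProvider

// Print the class name of every provider registered on the classpath.
object ListProviders {
  def main(args: Array[String]): Unit = {
    val it = ServiceLoader.load(classOf[SparkSessionExtensionsProvider]).iterator()
    while (it.hasNext) {
      println(it.next().getClass.getName)
    }
  }
}
```

Note that `loadExtensions` above wraps each provider invocation in a try/catch, so a misbehaving extension logs a warning rather than failing session creation.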

View file

@@ -71,7 +71,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
  * {{{
  *   SparkSession.builder()
  *     .master("...")
- *     .config("spark.sql.extensions", "org.example.MyExtensions")
+ *     .config("spark.sql.extensions", "org.example.MyExtensions,org.example.YourExtensions")
  *     .getOrCreate()
  *
  *   class MyExtensions extends Function1[SparkSessionExtensions, Unit] {
@@ -84,6 +84,15 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
  *       }
  *     }
  *   }
+ *
+ *   class YourExtensions extends SparkSessionExtensionsProvider {
+ *     override def apply(extensions: SparkSessionExtensions): Unit = {
+ *       extensions.injectResolutionRule { session =>
+ *         ...
+ *       }
+ *       extensions.injectFunction(...)
+ *     }
+ *   }
  * }}}
  *
  * Note that none of the injected builders should assume that the [[SparkSession]] is fully
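The `YourExtensions` example above elides its rule body with `...`; a minimal complete provider might look like the following sketch (`MyNoopRule` and the `org.example` package are illustrative, not part of this PR):

```scala
package org.example

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, SparkSessionExtensionsProvider}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// A resolution rule that leaves the plan unchanged, only for illustration.
case class MyNoopRule(session: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

class YourExtensions extends SparkSessionExtensionsProvider {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // injectResolutionRule takes a SparkSession => Rule[LogicalPlan] builder;
    // the case-class companion of MyNoopRule satisfies that shape.
    extensions.injectResolutionRule(MyNoopRule)
  }
}
```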

View file

@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql

import org.apache.spark.annotation.{DeveloperApi, Since, Unstable}

// scalastyle:off line.size.limit
/**
 * :: Unstable ::
 *
 * Base trait for implementations used by [[SparkSessionExtensions]].
 *
 * For example, suppose we have an external function named `Age` that we want to register as an extension for SparkSession:
 *
 * {{{
 *   package org.apache.spark.examples.extensions
 *
 *   import org.apache.spark.sql.catalyst.expressions.{CurrentDate, Expression, RuntimeReplaceable, SubtractDates}
 *
 *   case class Age(birthday: Expression, child: Expression) extends RuntimeReplaceable {
 *
 *     def this(birthday: Expression) = this(birthday, SubtractDates(CurrentDate(), birthday))
 *     override def exprsReplaced: Seq[Expression] = Seq(birthday)
 *     override protected def withNewChildInternal(newChild: Expression): Expression =
 *       copy(child = newChild)
 *   }
 * }}}
 *
 * We then create an extension that inherits from [[SparkSessionExtensionsProvider]], for example:
 *
 * {{{
 *   package org.apache.spark.examples.extensions
 *
 *   import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider}
 *   import org.apache.spark.sql.catalyst.FunctionIdentifier
 *   import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
 *
 *   class MyExtensions extends SparkSessionExtensionsProvider {
 *     override def apply(v1: SparkSessionExtensions): Unit = {
 *       v1.injectFunction(
 *         (new FunctionIdentifier("age"),
 *           new ExpressionInfo(classOf[Age].getName, "age"),
 *           (children: Seq[Expression]) => new Age(children.head)))
 *     }
 *   }
 * }}}
 *
 * Then, we can inject `MyExtensions` in three ways:
 * <ul>
 *   <li>withExtensions of [[SparkSession.Builder]]</li>
 *   <li>Config - spark.sql.extensions</li>
 *   <li>[[java.util.ServiceLoader]] - Add to src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider</li>
 * </ul>
 *
 * @see [[SparkSessionExtensions]]
 * @see [[SparkSession.Builder]]
 * @see [[java.util.ServiceLoader]]
 *
 * @since 3.2.0
 */
@DeveloperApi
@Unstable
@Since("3.2.0")
trait SparkSessionExtensionsProvider extends Function1[SparkSessionExtensions, Unit]
// scalastyle:on line.size.limit
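To make the three injection routes listed in the scaladoc concrete, a sketch using the `MyExtensions` provider from the example above (the `InjectionRoutes` object and the package placement of `MyExtensions` are illustrative, not part of this PR):

```scala
package org.example

import org.apache.spark.sql.SparkSession

// Hypothetical location of the provider from the scaladoc example.
import org.apache.spark.examples.extensions.MyExtensions

object InjectionRoutes {
  def main(args: Array[String]): Unit = {
    // 1. Programmatically, via SparkSession.Builder#withExtensions.
    val viaBuilder = SparkSession.builder()
      .master("local[1]")
      .withExtensions(new MyExtensions)
      .getOrCreate()
    viaBuilder.stop()

    // 2. Declaratively, via the spark.sql.extensions configuration.
    val viaConf = SparkSession.builder()
      .master("local[1]")
      .config("spark.sql.extensions", "org.apache.spark.examples.extensions.MyExtensions")
      .getOrCreate()
    viaConf.stop()

    // 3. Via ServiceLoader: nothing to do at session-build time. Any provider
    // registered under META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider
    // on the classpath is applied automatically by the new loadExtensions step.
    val viaServiceLoader = SparkSession.builder().master("local[1]").getOrCreate()
    viaServiceLoader.stop()
  }
}
```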

View file

@@ -0,0 +1 @@
org.apache.spark.sql.YourExtensions

View file

@@ -46,8 +46,8 @@ import org.apache.spark.unsafe.types.UTF8String
  * Test cases for the [[SparkSessionExtensions]].
  */
 class SparkSessionExtensionSuite extends SparkFunSuite {
-  type ExtensionsBuilder = SparkSessionExtensions => Unit
-  private def create(builder: ExtensionsBuilder): Seq[ExtensionsBuilder] = Seq(builder)
+  private def create(
+      builder: SparkSessionExtensionsProvider): Seq[SparkSessionExtensionsProvider] = Seq(builder)

   private def stop(spark: SparkSession): Unit = {
     spark.stop()
@@ -55,7 +55,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     SparkSession.clearDefaultSession()
   }

-  private def withSession(builders: Seq[ExtensionsBuilder])(f: SparkSession => Unit): Unit = {
+  private def withSession(
+      builders: Seq[SparkSessionExtensionsProvider])(f: SparkSession => Unit): Unit = {
     val builder = SparkSession.builder().master("local[1]")
     builders.foreach(builder.withExtensions)
     val spark = builder.getOrCreate()
@@ -355,6 +356,20 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
       stop(session)
     }
   }
+
+  test("SPARK-35380: Loading extensions from ServiceLoader") {
+    val builder = SparkSession.builder().master("local[1]")
+
+    Seq(None, Some(classOf[YourExtensions].getName)).foreach { ext =>
+      ext.foreach(builder.config(SPARK_SESSION_EXTENSIONS.key, _))
+      val session = builder.getOrCreate()
+      try {
+        assert(session.sql("select get_fake_app_name()").head().getString(0) === "Fake App Name")
+      } finally {
+        stop(session)
+      }
+    }
+  }
 }

 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -959,3 +974,16 @@ class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) {
     e.injectFunction(MyExtensions2Duplicate.myFunction)
   }
 }
+
+class YourExtensions extends SparkSessionExtensionsProvider {
+  val getAppName = (FunctionIdentifier("get_fake_app_name"),
+    new ExpressionInfo(
+      "zzz.zzz.zzz",
+      "",
+      "get_fake_app_name"),
+    (_: Seq[Expression]) => Literal("Fake App Name"))
+
+  override def apply(v1: SparkSessionExtensions): Unit = {
+    v1.injectFunction(getAppName)
+  }
+}